diff --git a/docs/apis/visualization.md b/docs/apis/visualization.md
index 10349aaf240..7815ec568b8 100644
--- a/docs/apis/visualization.md
+++ b/docs/apis/visualization.md
@@ -19,3 +19,22 @@ For a detailed tutorial, please refer to our [Visualization Tutorial](../tutoria
:undoc-members:
:show-inheritance:
```
+
+
+## Matplotlib-based components
+
+```{eval-rst}
+.. automodule:: mesa.visualization.components.matplotlib
+ :members:
+ :undoc-members:
+ :show-inheritance:
+```
+
+## Altair-based components
+
+```{eval-rst}
+.. automodule:: mesa.visualization.components.altair
+ :members:
+ :undoc-members:
+ :show-inheritance:
+```
\ No newline at end of file
diff --git a/docs/migration_guide.md b/docs/migration_guide.md
index b4f54c3dfd7..016053e84fb 100644
--- a/docs/migration_guide.md
+++ b/docs/migration_guide.md
@@ -268,9 +268,9 @@ from mesa.experimental import SolaraViz
SolaraViz(model_cls, model_params, agent_portrayal=agent_portrayal)
# new
-from mesa.visualization import SolaraViz, make_space_matplotlib
+from mesa.visualization import SolaraViz, make_space_component
-SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)])
+SolaraViz(model, components=[make_space_component(agent_portrayal)])
```
#### Plotting "measures"
diff --git a/docs/overview.md b/docs/overview.md
index 17ef67302fe..7d0c750ed84 100644
--- a/docs/overview.md
+++ b/docs/overview.md
@@ -168,11 +168,13 @@ The results are returned as a list of dictionaries, which can be easily converte
Mesa now uses a new browser-based visualization system called SolaraViz. This allows for interactive, customizable visualizations of your models. Here's a basic example of how to set up a visualization:
```python
-from mesa.visualization import SolaraViz, make_space_matplotlib, make_plot_measure
+from mesa.visualization import SolaraViz, make_space_component, make_plot_measure
+
def agent_portrayal(agent):
return {"color": "blue", "size": 50}
+
model_params = {
"N": {
"type": "SliderInt",
@@ -187,7 +189,7 @@ model_params = {
page = SolaraViz(
MyModel,
[
- make_space_matplotlib(agent_portrayal),
+ make_space_component(agent_portrayal),
make_plot_measure("mean_age")
],
model_params=model_params
diff --git a/docs/tutorials/visualization_tutorial.ipynb b/docs/tutorials/visualization_tutorial.ipynb
index 7460ca30a19..aff166e9b82 100644
--- a/docs/tutorials/visualization_tutorial.ipynb
+++ b/docs/tutorials/visualization_tutorial.ipynb
@@ -3,9 +3,7 @@
{
"cell_type": "markdown",
"metadata": {},
- "source": [
- "# Visualization Tutorial"
- ]
+ "source": "# Visualization Tutorial"
},
{
"cell_type": "markdown",
@@ -52,40 +50,50 @@
},
{
"cell_type": "code",
- "execution_count": 1,
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:46.075682Z",
+ "start_time": "2024-10-29T19:38:45.449918Z"
+ }
+ },
+ "source": [
+ "import mesa\n",
+ "print(f\"Mesa version: {mesa.__version__}\")\n",
+ "\n",
+ "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n",
+ "\n",
+ "# Import the local MoneyModel.py\n",
+ "from MoneyModel import MoneyModel\n"
+ ],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Mesa version: 3.0.0b1\n"
+ "Mesa version: 3.0.0b2\n"
]
}
],
- "source": [
- "import mesa\n",
- "print(f\"Mesa version: {mesa.__version__}\")\n",
- "\n",
- "from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib\n",
- "# Import the local MoneyModel.py\n",
- "from MoneyModel import MoneyModel\n"
- ]
+ "execution_count": 1
},
{
"cell_type": "code",
- "execution_count": null,
"metadata": {
- "tags": []
+ "tags": [],
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:46.079286Z",
+ "start_time": "2024-10-29T19:38:46.076984Z"
+ }
},
- "outputs": [],
"source": [
"def agent_portrayal(agent):\n",
" return {\n",
" \"color\": \"tab:blue\",\n",
" \"size\": 50,\n",
" }"
- ]
+ ],
+ "outputs": [],
+ "execution_count": 2
},
{
"cell_type": "markdown",
@@ -96,9 +104,12 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:46.081662Z",
+ "start_time": "2024-10-29T19:38:46.079838Z"
+ }
+ },
"source": [
"model_params = {\n",
" \"n\": {\n",
@@ -112,7 +123,9 @@
" \"width\": 10,\n",
" \"height\": 10,\n",
"}"
- ]
+ ],
+ "outputs": [],
+ "execution_count": 3
},
{
"cell_type": "markdown",
@@ -130,16 +143,18 @@
},
{
"cell_type": "code",
- "execution_count": null,
"metadata": {
- "tags": []
+ "tags": [],
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:46.864371Z",
+ "start_time": "2024-10-29T19:38:46.082810Z"
+ }
},
- "outputs": [],
"source": [
"# Create initial model instance\n",
"model1 = MoneyModel(50, 10, 10)\n",
"\n",
- "SpaceGraph = make_space_matplotlib(agent_portrayal)\n",
+ "SpaceGraph = make_space_component(agent_portrayal)\n",
"GiniPlot = make_plot_measure(\"Gini\")\n",
"\n",
"page = SolaraViz(\n",
@@ -150,7 +165,27 @@
")\n",
"# This is required to render the visualization in the Jupyter notebook\n",
"page"
- ]
+ ],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Cannot show ipywidgets in text"
+ ],
+ "text/html": [
+ "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter ⇧+↩)."
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "c9f2ef2b5a24483c92fa129213414a2c"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 4
},
{
"cell_type": "markdown",
@@ -169,23 +204,39 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:46.867576Z",
+ "start_time": "2024-10-29T19:38:46.865205Z"
+ }
+ },
"source": [
"import mesa\n",
"print(f\"Mesa version: {mesa.__version__}\")\n",
"\n",
- "from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib\n",
+ "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n",
"# Import the local MoneyModel.py\n",
"from MoneyModel import MoneyModel\n"
- ]
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mesa version: 3.0.0b2\n"
+ ]
+ }
+ ],
+ "execution_count": 5
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:46.870617Z",
+ "start_time": "2024-10-29T19:38:46.868336Z"
+ }
+ },
"source": [
"def agent_portrayal(agent):\n",
" size = 10\n",
@@ -207,18 +258,23 @@
" \"width\": 10,\n",
" \"height\": 10,\n",
"}"
- ]
+ ],
+ "outputs": [],
+ "execution_count": 6
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:47.881911Z",
+ "start_time": "2024-10-29T19:38:46.871328Z"
+ }
+ },
"source": [
"# Create initial model instance\n",
"model1 = MoneyModel(50, 10, 10)\n",
"\n",
- "SpaceGraph = make_space_matplotlib(agent_portrayal)\n",
+ "SpaceGraph = make_space_component(agent_portrayal)\n",
"GiniPlot = make_plot_measure(\"Gini\")\n",
"\n",
"page = SolaraViz(\n",
@@ -229,7 +285,27 @@
")\n",
"# This is required to render the visualization in the Jupyter notebook\n",
"page"
- ]
+ ],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Cannot show ipywidgets in text"
+ ],
+ "text/html": [
+ "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter ⇧+↩)."
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "da8518ec9ce74c068288bec0c8d3793e"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 7
},
{
"cell_type": "markdown",
@@ -250,9 +326,12 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:47.885386Z",
+ "start_time": "2024-10-29T19:38:47.882808Z"
+ }
+ },
"source": [
"import mesa\n",
"print(f\"Mesa version: {mesa.__version__}\")\n",
@@ -260,16 +339,29 @@
"from matplotlib.figure import Figure\n",
"\n",
"from mesa.visualization.utils import update_counter\n",
- "from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib\n",
+ "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n",
"# Import the local MoneyModel.py\n",
"from MoneyModel import MoneyModel\n"
- ]
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mesa version: 3.0.0b2\n"
+ ]
+ }
+ ],
+ "execution_count": 8
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:47.888491Z",
+ "start_time": "2024-10-29T19:38:47.886217Z"
+ }
+ },
"source": [
"def agent_portrayal(agent):\n",
" size = 10\n",
@@ -291,7 +383,9 @@
" \"width\": 10,\n",
" \"height\": 10,\n",
"}"
- ]
+ ],
+ "outputs": [],
+ "execution_count": 9
},
{
"cell_type": "markdown",
@@ -302,9 +396,12 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:47.893643Z",
+ "start_time": "2024-10-29T19:38:47.891084Z"
+ }
+ },
"source": [
"@solara.component\n",
"def Histogram(model):\n",
@@ -318,26 +415,36 @@
" # because plt.hist is not thread-safe.\n",
" ax.hist(wealth_vals, bins=10)\n",
" solara.FigureMatplotlib(fig)"
- ]
+ ],
+ "outputs": [],
+ "execution_count": 10
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:47.896565Z",
+ "start_time": "2024-10-29T19:38:47.894387Z"
+ }
+ },
"source": [
"# Create initial model instance\n",
"model1 = MoneyModel(50, 10, 10)\n",
"\n",
- "SpaceGraph = make_space_matplotlib(agent_portrayal)\n",
+ "SpaceGraph = make_space_component(agent_portrayal)\n",
"GiniPlot = make_plot_measure(\"Gini\")"
- ]
+ ],
+ "outputs": [],
+ "execution_count": 11
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:49.471838Z",
+ "start_time": "2024-10-29T19:38:47.897295Z"
+ }
+ },
"source": [
"page = SolaraViz(\n",
" model1,\n",
@@ -347,7 +454,27 @@
")\n",
"# This is required to render the visualization in the Jupyter notebook\n",
"page"
- ]
+ ],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Cannot show ipywidgets in text"
+ ],
+ "text/html": [
+ "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter ⇧+↩)."
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "bc71b89ee5684038a194eee4c36f4a4c"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 12
},
{
"cell_type": "markdown",
@@ -358,12 +485,35 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-10-29T19:38:49.505725Z",
+ "start_time": "2024-10-29T19:38:49.472599Z"
+ }
+ },
"source": [
"Histogram(model1)"
- ]
+ ],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Cannot show ipywidgets in text"
+ ],
+ "text/html": [
+ "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter ⇧+↩)."
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "0491f167a1434a92b78535078bd082a8"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 13
},
{
"cell_type": "markdown",
diff --git a/mesa/examples/advanced/epstein_civil_violence/app.py b/mesa/examples/advanced/epstein_civil_violence/app.py
index 862ca6220d8..538ef186f57 100644
--- a/mesa/examples/advanced/epstein_civil_violence/app.py
+++ b/mesa/examples/advanced/epstein_civil_violence/app.py
@@ -8,7 +8,7 @@
Slider,
SolaraViz,
make_plot_measure,
- make_space_matplotlib,
+ make_space_component,
)
COP_COLOR = "#000000"
@@ -47,7 +47,7 @@ def citizen_cop_portrayal(agent):
"max_jail_term": Slider("Max Jail Term", 30, 0, 50, 1),
}
-space_component = make_space_matplotlib(citizen_cop_portrayal)
+space_component = make_space_component(citizen_cop_portrayal)
chart_component = make_plot_measure([state.name.lower() for state in CitizenState])
epstein_model = EpsteinCivilViolence()
diff --git a/mesa/examples/advanced/pd_grid/app.py b/mesa/examples/advanced/pd_grid/app.py
index c8ceec9fe16..6edf8140536 100644
--- a/mesa/examples/advanced/pd_grid/app.py
+++ b/mesa/examples/advanced/pd_grid/app.py
@@ -3,7 +3,7 @@
"""
from mesa.examples.advanced.pd_grid.model import PdGrid
-from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib
+from mesa.visualization import SolaraViz, make_plot_measure, make_space_component
from mesa.visualization.UserParam import Slider
@@ -32,7 +32,7 @@ def pd_agent_portrayal(agent):
# Create grid visualization component using Altair
-grid_viz = make_space_matplotlib(agent_portrayal=pd_agent_portrayal)
+grid_viz = make_space_component(agent_portrayal=pd_agent_portrayal)
# Create plot for tracking cooperating agents over time
plot_component = make_plot_measure("Cooperating_Agents")
diff --git a/mesa/examples/advanced/sugarscape_g1mt/app.py b/mesa/examples/advanced/sugarscape_g1mt/app.py
index 7c8cc2cfead..752998891bd 100644
--- a/mesa/examples/advanced/sugarscape_g1mt/app.py
+++ b/mesa/examples/advanced/sugarscape_g1mt/app.py
@@ -1,11 +1,3 @@
-import os.path
-import sys
-
-sys.path.insert(
- 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../"))
-)
-
-
import numpy as np
import solara
from matplotlib.figure import Figure
diff --git a/mesa/examples/advanced/wolf_sheep/app.py b/mesa/examples/advanced/wolf_sheep/app.py
index b5ac6e8bf47..a8c0a1e9c49 100644
--- a/mesa/examples/advanced/wolf_sheep/app.py
+++ b/mesa/examples/advanced/wolf_sheep/app.py
@@ -4,12 +4,9 @@
Slider,
SolaraViz,
make_plot_measure,
- make_space_matplotlib,
+ make_space_component,
)
-WOLF_COLOR = "#000000"
-SHEEP_COLOR = "#648FFF"
-
def wolf_sheep_portrayal(agent):
if agent is None:
@@ -17,23 +14,23 @@ def wolf_sheep_portrayal(agent):
portrayal = {
"size": 25,
- "shape": "s", # square marker
}
if isinstance(agent, Wolf):
- portrayal["color"] = WOLF_COLOR
- portrayal["Layer"] = 3
+ portrayal["color"] = "tab:red"
+ portrayal["marker"] = "o"
+ portrayal["zorder"] = 2
elif isinstance(agent, Sheep):
- portrayal["color"] = SHEEP_COLOR
- portrayal["Layer"] = 2
+ portrayal["color"] = "tab:cyan"
+ portrayal["marker"] = "o"
+ portrayal["zorder"] = 2
elif isinstance(agent, GrassPatch):
if agent.fully_grown:
- portrayal["color"] = "#00FF00"
+ portrayal["color"] = "tab:green"
else:
- portrayal["color"] = "#84e184"
- # portrayal["shape"] = "rect"
- # portrayal["Filled"] = "true"
- portrayal["Layer"] = 1
+ portrayal["color"] = "tab:brown"
+ portrayal["marker"] = "s"
+ portrayal["size"] = 75
return portrayal
@@ -62,10 +59,20 @@ def wolf_sheep_portrayal(agent):
}
-space_component = make_space_matplotlib(wolf_sheep_portrayal)
-lineplot_component = make_plot_measure(["Wolves", "Sheep", "Grass"])
+def post_process(ax):
+ ax.set_aspect("equal")
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+
+space_component = make_space_component(
+ wolf_sheep_portrayal, draw_grid=False, post_process=post_process
+)
+lineplot_component = make_plot_measure(
+ {"Wolves": "tab:orange", "Sheep": "tab:cyan", "Grass": "tab:green"}
+)
-model = WolfSheep()
+model = WolfSheep(grass=True)
page = SolaraViz(
diff --git a/mesa/examples/basic/boid_flockers/app.py b/mesa/examples/basic/boid_flockers/app.py
index 482d582b8ba..bcecb0a3ebd 100644
--- a/mesa/examples/basic/boid_flockers/app.py
+++ b/mesa/examples/basic/boid_flockers/app.py
@@ -1,5 +1,5 @@
from mesa.examples.basic.boid_flockers.model import BoidFlockers
-from mesa.visualization import Slider, SolaraViz, make_space_matplotlib
+from mesa.visualization import Slider, SolaraViz, make_space_component
def boid_draw(agent):
@@ -51,7 +51,7 @@ def boid_draw(agent):
page = SolaraViz(
model,
- [make_space_matplotlib(agent_portrayal=boid_draw)],
+ [make_space_component(agent_portrayal=boid_draw)],
model_params=model_params,
name="Boid Flocking Model",
)
diff --git a/mesa/examples/basic/boltzmann_wealth_model/app.py b/mesa/examples/basic/boltzmann_wealth_model/app.py
index 7e3f41e64de..2ab6d06bf73 100644
--- a/mesa/examples/basic/boltzmann_wealth_model/app.py
+++ b/mesa/examples/basic/boltzmann_wealth_model/app.py
@@ -2,7 +2,7 @@
from mesa.visualization import (
SolaraViz,
make_plot_measure,
- make_space_matplotlib,
+ make_space_component,
)
@@ -36,7 +36,7 @@ def agent_portrayal(agent):
# Under the hood these are just classes that receive the model instance.
# You can also author your own visualization elements, which can also be functions
# that receive the model instance and return a valid solara component.
-SpaceGraph = make_space_matplotlib(agent_portrayal)
+SpaceGraph = make_space_component(agent_portrayal)
GiniPlot = make_plot_measure("Gini")
# Create the SolaraViz page. This will automatically create a server and display the
diff --git a/mesa/examples/basic/conways_game_of_life/app.py b/mesa/examples/basic/conways_game_of_life/app.py
index 2c9dace8635..7a45125a30a 100644
--- a/mesa/examples/basic/conways_game_of_life/app.py
+++ b/mesa/examples/basic/conways_game_of_life/app.py
@@ -1,12 +1,12 @@
from mesa.examples.basic.conways_game_of_life.model import ConwaysGameOfLife
from mesa.visualization import (
SolaraViz,
- make_space_matplotlib,
+ make_space_component,
)
def agent_portrayal(agent):
- return {"color": "white" if agent.state == 0 else "black"}
+ return {"c": "white" if agent.state == 0 else "black", "marker": "s"}
model_params = {
@@ -22,7 +22,7 @@ def agent_portrayal(agent):
# Under the hood these are just classes that receive the model instance.
# You can also author your own visualization elements, which can also be functions
# that receive the model instance and return a valid solara component.
-SpaceGraph = make_space_matplotlib(agent_portrayal)
+SpaceGraph = make_space_component(agent_portrayal)
# Create the SolaraViz page. This will automatically create a server and display the
diff --git a/mesa/examples/basic/schelling/app.py b/mesa/examples/basic/schelling/app.py
index 72ae6ddc1ec..53fab7ba0f0 100644
--- a/mesa/examples/basic/schelling/app.py
+++ b/mesa/examples/basic/schelling/app.py
@@ -5,7 +5,7 @@
Slider,
SolaraViz,
make_plot_measure,
- make_space_matplotlib,
+ make_space_component,
)
@@ -33,7 +33,7 @@ def agent_portrayal(agent):
page = SolaraViz(
model1,
components=[
- make_space_matplotlib(agent_portrayal),
+ make_space_component(agent_portrayal),
make_plot_measure("happy"),
get_happy_agents,
],
diff --git a/mesa/examples/basic/virus_on_network/app.py b/mesa/examples/basic/virus_on_network/app.py
index 0183d256790..7cf54f308d5 100644
--- a/mesa/examples/basic/virus_on_network/app.py
+++ b/mesa/examples/basic/virus_on_network/app.py
@@ -9,7 +9,7 @@
VirusOnNetwork,
number_infected,
)
-from mesa.visualization import Slider, SolaraViz, make_space_matplotlib
+from mesa.visualization import Slider, SolaraViz, make_space_component
def agent_portrayal(graph):
@@ -119,7 +119,7 @@ def make_plot(model):
),
}
-SpacePlot = make_space_matplotlib(agent_portrayal)
+SpacePlot = make_space_component(agent_portrayal)
model1 = VirusOnNetwork()
diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py
index d6e50c37e36..0e1875c751c 100644
--- a/mesa/visualization/__init__.py
+++ b/mesa/visualization/__init__.py
@@ -1,7 +1,7 @@
"""Solara based visualization for Mesa models."""
from .components.altair import make_space_altair
-from .components.matplotlib import make_plot_measure, make_space_matplotlib
+from .components.matplotlib import make_plot_measure, make_space_component
from .solara_viz import JupyterViz, SolaraViz
from .UserParam import Slider
@@ -10,6 +10,6 @@
"SolaraViz",
"Slider",
"make_space_altair",
- "make_space_matplotlib",
+ "make_space_component",
"make_plot_measure",
]
diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py
index bea633d6c8b..7e9982a7387 100644
--- a/mesa/visualization/components/matplotlib.py
+++ b/mesa/visualization/components/matplotlib.py
@@ -1,27 +1,63 @@
"""Matplotlib based solara components for visualization MESA spaces and plots."""
+import itertools
+import math
import warnings
+from collections.abc import Callable
+from typing import Any
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import solara
+from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
+from matplotlib.collections import PatchCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.figure import Figure
+from matplotlib.patches import RegularPolygon
import mesa
-from mesa.experimental.cell_space import Grid, VoronoiGrid
-from mesa.space import PropertyLayer
+from mesa.experimental.cell_space import (
+ OrthogonalMooreGrid,
+ OrthogonalVonNeumannGrid,
+ VoronoiGrid,
+)
+from mesa.space import (
+ ContinuousSpace,
+ HexMultiGrid,
+ HexSingleGrid,
+ MultiGrid,
+ NetworkGrid,
+ PropertyLayer,
+ SingleGrid,
+)
from mesa.visualization.utils import update_counter
+# For typing
+OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
+HexGrid = HexSingleGrid | HexMultiGrid | mesa.experimental.cell_space.HexGrid
+Network = NetworkGrid | mesa.experimental.cell_space.Network
-def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None):
+
+def make_space_component(
+ agent_portrayal: Callable | None = None,
+ propertylayer_portrayal: dict | None = None,
+ post_process: Callable | None = None,
+ **space_drawing_kwargs,
+):
"""Create a Matplotlib-based space visualization component.
Args:
- agent_portrayal (function): Function to portray agents
- propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications
+ agent_portrayal: Function to portray agents.
+ propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
+ post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
+ space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See
+ the functions for drawing the various spaces for further details.
+
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
+
Returns:
function: A function that creates a SpaceMatplotlib component
@@ -29,10 +65,16 @@ def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None):
if agent_portrayal is None:
def agent_portrayal(a):
- return {"id": a.unique_id}
+ return {}
def MakeSpaceMatplotlib(model):
- return SpaceMatplotlib(model, agent_portrayal, propertylayer_portrayal)
+ return SpaceMatplotlib(
+ model,
+ agent_portrayal,
+ propertylayer_portrayal,
+ post_process=post_process,
+ **space_drawing_kwargs,
+ )
return MakeSpaceMatplotlib
@@ -43,48 +85,157 @@ def SpaceMatplotlib(
agent_portrayal,
propertylayer_portrayal,
dependencies: list[any] | None = None,
+ post_process: Callable | None = None,
+ **space_drawing_kwargs,
):
"""Create a Matplotlib-based space visualization component."""
update_counter.get()
- space_fig = Figure()
- space_ax = space_fig.subplots()
+
space = getattr(model, "grid", None)
if space is None:
space = getattr(model, "space", None)
+ fig = Figure()
+ ax = fig.add_subplot()
+
+ draw_space(
+ space,
+ agent_portrayal,
+ propertylayer_portrayal=propertylayer_portrayal,
+ ax=ax,
+ post_process=post_process,
+ **space_drawing_kwargs,
+ )
+
+ solara.FigureMatplotlib(
+ fig, format="png", bbox_inches="tight", dependencies=dependencies
+ )
+
+
+def collect_agent_data(
+ space: OrthogonalGrid | HexGrid | Network | ContinuousSpace | VoronoiGrid,
+ agent_portrayal: Callable,
+ color="tab:blue",
+ size=25,
+ marker="o",
+ zorder: int = 1,
+):
+ """Collect the plotting data for all agents in the space.
+
+ Args:
+ space: The space containing the Agents.
+ agent_portrayal: A callable that is called with the agent and returns a dict
+ color: default color
+ size: default size
+ marker: default marker
+ zorder: default zorder
+
+ agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order),
+ and marker (marker style)
+
+ """
+ arguments = {"s": [], "c": [], "marker": [], "zorder": [], "loc": []}
+
+ for agent in space.agents:
+ portray = agent_portrayal(agent)
+ loc = agent.pos
+ if loc is None:
+ loc = agent.cell.coordinate
+
+ arguments["loc"].append(loc)
+ arguments["s"].append(portray.pop("size", size))
+ arguments["c"].append(portray.pop("color", color))
+ arguments["marker"].append(portray.pop("marker", marker))
+ arguments["zorder"].append(portray.pop("zorder", zorder))
+
+ if len(portray) > 0:
+ ignored_fields = list(portray.keys())
+ msg = ", ".join(ignored_fields)
+ warnings.warn(
+ f"the following fields are not used in agent portrayal and thus ignored: {msg}.",
+ stacklevel=2,
+ )
+
+ return {k: np.asarray(v) for k, v in arguments.items()}
+
+
+def draw_space(
+ space,
+ agent_portrayal: Callable,
+ propertylayer_portrayal: dict | None = None,
+ ax: Axes | None = None,
+ post_process: Callable | None = None,
+ **space_drawing_kwargs,
+):
+ """Draw a Matplotlib-based visualization of the space.
+
+ Args:
+ space: the space of the mesa model
+ agent_portrayal: A callable that returns a dict specifying how to show the agent
+ propertylayer_portrayal: a dict specifying how to show propertylayer(s)
+ ax: the axes upon which to draw the plot
+ post_process: a callable called with the Axes instance
+ postprocess: a user-specified callable to do post-processing called with the Axes instance. This callable
+ can be used for any further fine-tuning of the plot (e.g., changing ticks, etc.)
+ space_drawing_kwargs: any additional keyword arguments to be passed on to the underlying function for drawing the space.
+
+ Returns:
+ Returns the Axes object with the plot drawn onto it.
+
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
+
+ """
+ if ax is None:
+ fig, ax = plt.subplots()
+
# https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching
match space:
- case mesa.space._Grid():
- _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
+ case mesa.space._Grid() | OrthogonalMooreGrid() | OrthogonalVonNeumannGrid():
+ draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
+ case HexSingleGrid() | HexMultiGrid() | mesa.experimental.cell_space.HexGrid():
+ draw_hex_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
+ case mesa.space.NetworkGrid() | mesa.experimental.cell_space.Network():
+ draw_network(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
case mesa.space.ContinuousSpace():
- _draw_continuous_space(space, space_ax, agent_portrayal, model)
- case mesa.space.NetworkGrid():
- _draw_network_grid(space, space_ax, agent_portrayal)
+ draw_continuous_space(space, agent_portrayal, ax=ax)
case VoronoiGrid():
- _draw_voronoi(space, space_ax, agent_portrayal)
- case Grid(): # matches OrthogonalMooreGrid, OrthogonalVonNeumannGrid, and Hexgrid
- # fixme add a separate draw method for hexgrids in the future
- _draw_discrete_space_grid(space, space_ax, agent_portrayal)
- case None:
- if propertylayer_portrayal:
- draw_property_layers(space_ax, space, propertylayer_portrayal, model)
+ draw_voroinoi_grid(space, agent_portrayal, ax=ax)
- solara.FigureMatplotlib(
- space_fig, format="png", bbox_inches="tight", dependencies=dependencies
- )
+ if propertylayer_portrayal:
+ draw_property_layers(space, propertylayer_portrayal, ax=ax)
+
+ if post_process is not None:
+ post_process(ax=ax)
+ return ax
-def draw_property_layers(ax, space, propertylayer_portrayal, model):
+
+def draw_property_layers(
+ space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes
+):
"""Draw PropertyLayers on the given axes.
Args:
- ax (matplotlib.axes.Axes): The axes to draw on.
space (mesa.space._Grid): The space containing the PropertyLayers.
- propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications.
- model (mesa.Model): The model instance.
+ propertylayer_portrayal (dict): the key is the name of the layer, the value is a dict with
+ fields specifying how the layer is to be portrayed
+ ax (matplotlib.axes.Axes): The axes to draw on.
+
+ Notes:
+ valid fields in in the inner dict of propertylayer_portrayal are "alpha", "vmin", "vmax", "color" or "colormap", and "colorbar"
+ so you can do `{"some_layer":{"colormap":'viridis', 'alpha':.25, "colorbar":False}}`
+
"""
+ try:
+ # old style spaces
+ property_layers = space.properties
+ except AttributeError:
+ # new style spaces
+ property_layers = space.property_layers
+
for layer_name, portrayal in propertylayer_portrayal.items():
- layer = getattr(model, layer_name, None)
+ layer = property_layers.get(layer_name, None)
if not isinstance(layer, PropertyLayer):
continue
@@ -116,7 +267,6 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model):
)
im = ax.imshow(
rgba_data.transpose(1, 0, 2),
- extent=(0, width, 0, height),
origin="lower",
)
if colorbar:
@@ -135,7 +285,6 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model):
alpha=alpha,
vmin=vmin,
vmax=vmax,
- extent=(0, width, 0, height),
origin="lower",
)
if colorbar:
@@ -146,131 +295,272 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model):
)
-def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model):
- if propertylayer_portrayal:
- draw_property_layers(space_ax, space, propertylayer_portrayal, model)
+def draw_orthogonal_grid(
+ space: OrthogonalGrid,
+ agent_portrayal: Callable,
+ ax: Axes | None = None,
+ draw_grid: bool = True,
+):
+ """Visualize a orthogonal grid.
+
+ Args:
+ space: the space to visualize
+ agent_portrayal: a callable that is called with the agent and returns a dict
+ ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
+ draw_grid: whether to draw the grid
- agent_data = _get_agent_data(space, agent_portrayal)
+ Returns:
+ Returns the Axes object with the plot drawn onto it.
+
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
- space_ax.set_xlim(0, space.width)
- space_ax.set_ylim(0, space.height)
- _split_and_scatter(agent_data, space_ax)
+ """
+ if ax is None:
+ fig, ax = plt.subplots()
- # Draw grid lines
- for x in range(space.width + 1):
- space_ax.axvline(x, color="gray", linestyle=":")
- for y in range(space.height + 1):
- space_ax.axhline(y, color="gray", linestyle=":")
+ # gather agent data
+ s_default = (180 / max(space.width, space.height)) ** 2
+ arguments = collect_agent_data(space, agent_portrayal, size=s_default)
+ # plot the agents
+ _scatter(ax, arguments)
-def _get_agent_data(space, agent_portrayal):
- """Helper function to get agent data for visualization."""
- x, y, s, c, m = [], [], [], [], []
- for agents, pos in space.coord_iter():
- if not agents:
- continue
- if not isinstance(agents, list):
- agents = [agents] # noqa PLW2901
- for agent in agents:
- data = agent_portrayal(agent)
- x.append(pos[0] + 0.5) # Center the agent in the cell
- y.append(pos[1] + 0.5) # Center the agent in the cell
- default_size = (180 / max(space.width, space.height)) ** 2
- s.append(data.get("size", default_size))
- c.append(data.get("color", "tab:blue"))
- m.append(data.get("shape", "o"))
- return {"x": x, "y": y, "s": s, "c": c, "m": m}
-
-
-def _split_and_scatter(portray_data, space_ax):
- """Helper function to split and scatter agent data."""
- for marker in set(portray_data["m"]):
- mask = [m == marker for m in portray_data["m"]]
- space_ax.scatter(
- [x for x, show in zip(portray_data["x"], mask) if show],
- [y for y, show in zip(portray_data["y"], mask) if show],
- s=[s for s, show in zip(portray_data["s"], mask) if show],
- c=[c for c, show in zip(portray_data["c"], mask) if show],
- marker=marker,
+ # further styling
+ ax.set_xlim(-0.5, space.width - 0.5)
+ ax.set_ylim(-0.5, space.height - 0.5)
+
+ if draw_grid:
+ # Draw grid lines
+ for x in np.arange(-0.5, space.width - 0.5, 1):
+ ax.axvline(x, color="gray", linestyle=":")
+ for y in np.arange(-0.5, space.height - 0.5, 1):
+ ax.axhline(y, color="gray", linestyle=":")
+
+ return ax
+
+
+def draw_hex_grid(
+ space: HexGrid,
+ agent_portrayal: Callable,
+ ax: Axes | None = None,
+ draw_grid: bool = True,
+):
+ """Visualize a hex grid.
+
+ Args:
+ space: the space to visualize
+ agent_portrayal: a callable that is called with the agent and returns a dict
+ ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
+ draw_grid: whether to draw the grid
+
+ Returns:
+ Returns the Axes object with the plot drawn onto it.
+
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
+
+ """
+ if ax is None:
+ fig, ax = plt.subplots()
+
+ # gather data
+ s_default = (180 / max(space.width, space.height)) ** 2
+ arguments = collect_agent_data(space, agent_portrayal, size=s_default)
+
+ # for hexgrids we have to go from logical coordinates to visual coordinates
+ # this is a bit messy.
+
+ # give all even rows an offset in the x direction
+ # give all rows an offset in the y direction
+
+ # numbers here are based on a distance of 1 between centers of hexes
+ offset = math.sqrt(0.75)
+
+ loc = arguments["loc"].astype(float)
+
+ logical = np.mod(loc[:, 1], 2) == 0
+ loc[:, 0][logical] += 0.5
+ loc[:, 1] *= offset
+ arguments["loc"] = loc
+
+ # plot the agents
+ _scatter(ax, arguments)
+
+ # further styling and adding of grid
+ ax.set_xlim(-1, space.width + 0.5)
+ ax.set_ylim(-offset, space.height * offset)
+
+ def setup_hexmesh(
+ width,
+ height,
+ ):
+ """Helper function for creating the hexmaesh."""
+ # fixme: this should be done once, rather than in each update
+ # fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset)
+
+ patches = []
+ for x, y in itertools.product(range(width), range(height)):
+ if y % 2 == 0:
+ x += 0.5 # noqa: PLW2901
+ y *= offset # noqa: PLW2901
+ hex = RegularPolygon(
+ (x, y),
+ numVertices=6,
+ radius=math.sqrt(1 / 3),
+ orientation=np.radians(120),
+ )
+ patches.append(hex)
+ mesh = PatchCollection(
+ patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1
+ )
+ return mesh
+
+ if draw_grid:
+ # add grid
+ ax.add_collection(
+ setup_hexmesh(
+ space.width,
+ space.height,
+ )
)
+ return ax
-def _draw_network_grid(space, space_ax, agent_portrayal):
+def draw_network(
+ space: Network,
+ agent_portrayal: Callable,
+ ax: Axes | None = None,
+ draw_grid: bool = True,
+ layout_alg=nx.spring_layout,
+ layout_kwargs=None,
+):
+ """Visualize a network space.
+
+ Args:
+ space: the space to visualize
+ agent_portrayal: a callable that is called with the agent and returns a dict
+ ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
+ draw_grid: whether to draw the grid
+ layout_alg: a networkx layout algorithm or other callable with the same behavior
+ layout_kwargs: a dictionary of keyword arguments for the layout algorithm
+
+ Returns:
+ Returns the Axes object with the plot drawn onto it.
+
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
+
+ """
+ if ax is None:
+ fig, ax = plt.subplots()
+ if layout_kwargs is None:
+ layout_kwargs = {"seed": 0}
+
+ # gather locations for nodes in network
graph = space.G
- pos = nx.spring_layout(graph, seed=0)
- nx.draw(
- graph,
- ax=space_ax,
- pos=pos,
- **agent_portrayal(graph),
- )
+ pos = layout_alg(graph, **layout_kwargs)
+ x, y = list(zip(*pos.values()))
+ xmin, xmax = min(x), max(x)
+ ymin, ymax = min(y), max(y)
+ width = xmax - xmin
+ height = ymax - ymin
+ x_padding = width / 20
+ y_padding = height / 20
-def _draw_continuous_space(space, space_ax, agent_portrayal, model):
- def portray(space):
- x = []
- y = []
- s = [] # size
- c = [] # color
- m = [] # shape
- for agent in space._agent_to_index:
- data = agent_portrayal(agent)
- _x, _y = agent.pos
- x.append(_x)
- y.append(_y)
-
- # This is matplotlib's default marker size
- default_size = 20
- size = data.get("size", default_size)
- s.append(size)
- color = data.get("color", "tab:blue")
- c.append(color)
- mark = data.get("shape", "o")
- m.append(mark)
- return {"x": x, "y": y, "s": s, "c": c, "m": m}
-
- # Determine border style based on space.torus
- border_style = "solid" if not space.torus else (0, (5, 10))
+ # gather agent data
+ s_default = (180 / max(width, height)) ** 2
+ arguments = collect_agent_data(space, agent_portrayal, size=s_default)
- # Set the border of the plot
- for spine in space_ax.spines.values():
- spine.set_linewidth(1.5)
- spine.set_color("black")
- spine.set_linestyle(border_style)
+ # this assumes that nodes are identified by an integer
+ # which is true for default nx graphs but might user changeable
+ pos = np.asarray(list(pos.values()))
+ arguments["loc"] = pos[arguments["loc"]]
+
+ # plot the agents
+ _scatter(ax, arguments)
+
+ # further styling
+ ax.set_axis_off()
+ ax.set_xlim(xmin=xmin - x_padding, xmax=xmax + x_padding)
+ ax.set_ylim(ymin=ymin - y_padding, ymax=ymax + y_padding)
+
+ if draw_grid:
+ # fixme we need to draw the empty nodes as well
+ edge_collection = nx.draw_networkx_edges(
+ graph, pos, ax=ax, alpha=0.5, style="--"
+ )
+ edge_collection.set_zorder(0)
+
+ return ax
+
+
+def draw_continuous_space(
+ space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None
+):
+ """Visualize a continuous space.
+
+ Args:
+ space: the space to visualize
+ agent_portrayal: a callable that is called with the agent and returns a dict
+ ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
+
+ Returns:
+ Returns the Axes object with the plot drawn onto it.
+
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
+
+ """
+ if ax is None:
+ fig, ax = plt.subplots()
+ # space related setup
width = space.x_max - space.x_min
x_padding = width / 20
height = space.y_max - space.y_min
y_padding = height / 20
- space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding)
- space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding)
-
- # Portray and scatter the agents in the space
- _split_and_scatter(portray(space), space_ax)
-
-
-def _draw_voronoi(space, space_ax, agent_portrayal):
- def portray(g):
- x = []
- y = []
- s = [] # size
- c = [] # color
-
- for cell in g.all_cells:
- for agent in cell.agents:
- data = agent_portrayal(agent)
- x.append(cell.coordinate[0])
- y.append(cell.coordinate[1])
- if "size" in data:
- s.append(data["size"])
- if "color" in data:
- c.append(data["color"])
- out = {"x": x, "y": y}
- out["s"] = s
- if len(c) > 0:
- out["c"] = c
-
- return out
+
+ # gather agent data
+ s_default = (180 / max(width, height)) ** 2
+ arguments = collect_agent_data(space, agent_portrayal, size=s_default)
+
+ # plot the agents
+ _scatter(ax, arguments)
+
+ # further visual styling
+ border_style = "solid" if not space.torus else (0, (5, 10))
+ for spine in ax.spines.values():
+ spine.set_linewidth(1.5)
+ spine.set_color("black")
+ spine.set_linestyle(border_style)
+
+ ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding)
+ ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding)
+
+ return ax
+
+
+def draw_voroinoi_grid(
+ space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None
+):
+ """Visualize a voronoi grid.
+
+ Args:
+ space: the space to visualize
+ agent_portrayal: a callable that is called with the agent and returns a dict
+ ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
+
+ Returns:
+ Returns the Axes object with the plot drawn onto it.
+
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
+
+ """
+ if ax is None:
+ fig, ax = plt.subplots()
x_list = [i[0] for i in space.centroids_coordinates]
y_list = [i[1] for i in space.centroids_coordinates]
@@ -283,56 +573,49 @@ def portray(g):
x_padding = width / 20
height = y_max - y_min
y_padding = height / 20
- space_ax.set_xlim(x_min - x_padding, x_max + x_padding)
- space_ax.set_ylim(y_min - y_padding, y_max + y_padding)
- space_ax.scatter(**portray(space))
+
+ s_default = (180 / max(width, height)) ** 2
+ arguments = collect_agent_data(space, agent_portrayal, size=s_default)
+
+ ax.set_xlim(x_min - x_padding, x_max + x_padding)
+ ax.set_ylim(y_min - y_padding, y_max + y_padding)
+
+ _scatter(ax, arguments)
for cell in space.all_cells:
polygon = cell.properties["polygon"]
- space_ax.fill(
+ ax.fill(
*zip(*polygon),
alpha=min(1, cell.properties[space.cell_coloring_property]),
c="red",
+ zorder=0,
) # Plot filled polygon
- space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black
-
-
-def _draw_discrete_space_grid(space: Grid, space_ax, agent_portrayal):
- if space._ndims != 2:
- raise ValueError("Space must be 2D")
-
- def portray(g):
- x = []
- y = []
- s = [] # size
- c = [] # color
-
- for cell in g.all_cells:
- for agent in cell.agents:
- data = agent_portrayal(agent)
- x.append(cell.coordinate[0])
- y.append(cell.coordinate[1])
- if "size" in data:
- s.append(data["size"])
- if "color" in data:
- c.append(data["color"])
- out = {"x": x, "y": y}
- out["s"] = s
- if len(c) > 0:
- out["c"] = c
-
- return out
-
- space_ax.set_xlim(0, space.width)
- space_ax.set_ylim(0, space.height)
-
- # Draw grid lines
- for x in range(space.width + 1):
- space_ax.axvline(x, color="gray", linestyle=":")
- for y in range(space.height + 1):
- space_ax.axhline(y, color="gray", linestyle=":")
-
- space_ax.scatter(**portray(space))
+ ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black
+
+ return ax
+
+
+def _scatter(ax: Axes, arguments):
+ """Helper function for plotting the agents."""
+ loc = arguments.pop("loc")
+
+ x = loc[:, 0]
+ y = loc[:, 1]
+ marker = arguments.pop("marker")
+ zorder = arguments.pop("zorder")
+
+ for mark in np.unique(marker):
+ mark_mask = marker == mark
+ for z_order in np.unique(zorder):
+ zorder_mask = z_order == zorder
+ logical = mark_mask & zorder_mask
+ ax.scatter(
+ x[logical],
+ y[logical],
+ marker=mark,
+ zorder=z_order,
+ **{k: v[logical] for k, v in arguments.items()},
+ )
def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]):
diff --git a/tests/test_components_matplotlib.py b/tests/test_components_matplotlib.py
new file mode 100644
index 00000000000..c85dd1ce292
--- /dev/null
+++ b/tests/test_components_matplotlib.py
@@ -0,0 +1,158 @@
+"""tests for matplotlib components."""
+
+import matplotlib.pyplot as plt
+
+from mesa import Agent, Model
+from mesa.experimental.cell_space import (
+ CellAgent,
+ HexGrid,
+ Network,
+ OrthogonalMooreGrid,
+ VoronoiGrid,
+)
+from mesa.space import (
+ ContinuousSpace,
+ HexSingleGrid,
+ NetworkGrid,
+ PropertyLayer,
+ SingleGrid,
+)
+from mesa.visualization.components.matplotlib import (
+ draw_continuous_space,
+ draw_hex_grid,
+ draw_network,
+ draw_orthogonal_grid,
+ draw_property_layers,
+ draw_voroinoi_grid,
+)
+
+
+def agent_portrayal(agent):
+ """Simple portrayal of an agent.
+
+ Args:
+ agent (Agent): The agent to portray
+
+ """
+ return {
+ "s": 10,
+ "c": "tab:blue",
+ "marker": "s" if (agent.unique_id % 2) == 0 else "o",
+ }
+
+
+def test_draw_hex_grid():
+ """Test drawing hexgrids."""
+ model = Model(seed=42)
+ grid = HexSingleGrid(10, 10, torus=True)
+ for _ in range(10):
+ agent = Agent(model)
+ grid.move_to_empty(agent)
+
+ fig, ax = plt.subplots()
+ draw_hex_grid(grid, agent_portrayal, ax)
+
+ model = Model(seed=42)
+ grid = HexGrid((10, 10), torus=True, random=model.random, capacity=1)
+ for _ in range(10):
+ agent = CellAgent(model)
+ agent.cell = grid.select_random_empty_cell()
+
+ fig, ax = plt.subplots()
+ draw_hex_grid(grid, agent_portrayal, ax)
+
+
+def test_draw_voroinoi_grid():
+ """Test drawing voroinoi grids."""
+ model = Model(seed=42)
+
+ coordinates = model.rng.random((100, 2)) * 10
+
+ grid = VoronoiGrid(coordinates.tolist(), random=model.random, capacity=1)
+ for _ in range(10):
+ agent = CellAgent(model)
+ agent.cell = grid.select_random_empty_cell()
+
+ fig, ax = plt.subplots()
+ draw_voroinoi_grid(grid, agent_portrayal, ax)
+
+
+def test_draw_orthogonal_grid():
+ """Test drawing orthogonal grids."""
+ model = Model(seed=42)
+ grid = SingleGrid(10, 10, torus=True)
+ for _ in range(10):
+ agent = Agent(model)
+ grid.move_to_empty(agent)
+
+ fig, ax = plt.subplots()
+ draw_orthogonal_grid(grid, agent_portrayal, ax)
+
+ model = Model(seed=42)
+ grid = OrthogonalMooreGrid((10, 10), torus=True, random=model.random, capacity=1)
+ for _ in range(10):
+ agent = CellAgent(model)
+ agent.cell = grid.select_random_empty_cell()
+
+ fig, ax = plt.subplots()
+ draw_orthogonal_grid(grid, agent_portrayal, ax)
+
+
+def test_draw_continuous_space():
+ """Test drawing continuous space."""
+ model = Model(seed=42)
+ space = ContinuousSpace(10, 10, torus=True)
+ for _ in range(10):
+ x = model.random.random() * 10
+ y = model.random.random() * 10
+ agent = Agent(model)
+ space.place_agent(agent, (x, y))
+
+ fig, ax = plt.subplots()
+ draw_continuous_space(space, agent_portrayal, ax)
+
+
+def test_draw_network():
+ """Test drawing network."""
+ import networkx as nx
+
+ n = 10
+ m = 20
+ seed = 42
+ graph = nx.gnm_random_graph(n, m, seed=seed)
+
+ model = Model(seed=42)
+ grid = NetworkGrid(graph)
+ for _ in range(10):
+ agent = Agent(model)
+ pos = agent.random.randint(0, len(graph.nodes) - 1)
+ grid.place_agent(agent, pos)
+
+ fig, ax = plt.subplots()
+ draw_network(grid, agent_portrayal, ax)
+
+ model = Model(seed=42)
+ grid = Network(graph, random=model.random, capacity=1)
+ for _ in range(10):
+ agent = CellAgent(model)
+ agent.cell = grid.select_random_empty_cell()
+
+ fig, ax = plt.subplots()
+ draw_network(grid, agent_portrayal, ax)
+
+
+def test_draw_property_layers():
+ """Test drawing property layers."""
+ model = Model(seed=42)
+ grid = SingleGrid(10, 10, torus=True)
+ grid.add_property_layer(PropertyLayer("test", grid.width, grid.height, 0))
+
+ fig, ax = plt.subplots()
+ draw_property_layers(grid, {"test": {"colormap": "viridis", "colorbar": True}}, ax)
+
+ model = Model(seed=42)
+ grid = OrthogonalMooreGrid((10, 10), torus=True, random=model.random, capacity=1)
+ grid.add_property_layer(PropertyLayer("test", grid.width, grid.height, 0))
+
+ fig, ax = plt.subplots()
+ draw_property_layers(grid, {"test": {"colormap": "viridis", "colorbar": True}}, ax)
diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py
index a0d2b449399..af6badd0bc2 100644
--- a/tests/test_solara_viz.py
+++ b/tests/test_solara_viz.py
@@ -8,7 +8,7 @@
import mesa
import mesa.visualization.components.altair
import mesa.visualization.components.matplotlib
-from mesa.visualization.components.matplotlib import make_space_matplotlib
+from mesa.visualization.components.matplotlib import make_space_component
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs
@@ -97,16 +97,16 @@ def test_call_space_drawer(mocker): # noqa: D103
mocker.patch.object(mesa.Model, "__init__", return_value=None)
agent_portrayal = {
- "Shape": "circle",
+ "marker": "circle",
"color": "gray",
}
propertylayer_portrayal = None
# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
- solara.render(SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)]))
+ solara.render(SolaraViz(model, components=[make_space_component(agent_portrayal)]))
# should call default method with class instance and agent portrayal
mock_space_matplotlib.assert_called_with(
- model, agent_portrayal, propertylayer_portrayal
+ model, agent_portrayal, propertylayer_portrayal, post_process=None
)
# specify no space should be drawn
@@ -132,7 +132,7 @@ def drawer(model):
centroids_coordinates=[(0, 1), (0, 0), (1, 0)],
)
solara.render(
- SolaraViz(voronoi_model, components=[make_space_matplotlib(agent_portrayal)])
+ SolaraViz(voronoi_model, components=[make_space_component(agent_portrayal)])
)