diff --git a/docs/apis/visualization.md b/docs/apis/visualization.md index 10349aaf240..7815ec568b8 100644 --- a/docs/apis/visualization.md +++ b/docs/apis/visualization.md @@ -19,3 +19,22 @@ For a detailed tutorial, please refer to our [Visualization Tutorial](../tutoria :undoc-members: :show-inheritance: ``` + + +## Matplotlib-based components + +```{eval-rst} +.. automodule:: mesa.visualization.components.matplotlib + :members: + :undoc-members: + :show-inheritance: +``` + +## Altair-based components + +```{eval-rst} +.. automodule:: mesa.visualization.components.altair + :members: + :undoc-members: + :show-inheritance: +``` \ No newline at end of file diff --git a/docs/migration_guide.md b/docs/migration_guide.md index b4f54c3dfd7..016053e84fb 100644 --- a/docs/migration_guide.md +++ b/docs/migration_guide.md @@ -268,9 +268,9 @@ from mesa.experimental import SolaraViz SolaraViz(model_cls, model_params, agent_portrayal=agent_portrayal) # new -from mesa.visualization import SolaraViz, make_space_matplotlib +from mesa.visualization import SolaraViz, make_space_component -SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)]) +SolaraViz(model, components=[make_space_component(agent_portrayal)]) ``` #### Plotting "measures" diff --git a/docs/overview.md b/docs/overview.md index 17ef67302fe..7d0c750ed84 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -168,11 +168,13 @@ The results are returned as a list of dictionaries, which can be easily converte Mesa now uses a new browser-based visualization system called SolaraViz. This allows for interactive, customizable visualizations of your models. Here's a basic example of how to set up a visualization: ```python -from mesa.visualization import SolaraViz, make_space_matplotlib, make_plot_measure +from mesa.visualization import SolaraViz, make_space_component, make_plot_measure + def agent_portrayal(agent): return {"color": "blue", "size": 50} + model_params = { "N": { "type": "SliderInt", @@ -187,7 +189,7 @@ model_params = { page = SolaraViz( MyModel, [ - make_space_matplotlib(agent_portrayal), + make_space_component(agent_portrayal), make_plot_measure("mean_age") ], model_params=model_params diff --git a/docs/tutorials/visualization_tutorial.ipynb b/docs/tutorials/visualization_tutorial.ipynb index 7460ca30a19..aff166e9b82 100644 --- a/docs/tutorials/visualization_tutorial.ipynb +++ b/docs/tutorials/visualization_tutorial.ipynb @@ -3,9 +3,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "# Visualization Tutorial" - ] + "source": "# Visualization Tutorial" }, { "cell_type": "markdown", @@ -52,40 +50,50 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:46.075682Z", + "start_time": "2024-10-29T19:38:45.449918Z" + } + }, + "source": [ + "import mesa\n", + "print(f\"Mesa version: {mesa.__version__}\")\n", + "\n", + "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n", + "\n", + "# Import the local MoneyModel.py\n", + "from MoneyModel import MoneyModel\n" + ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Mesa version: 3.0.0b1\n" + "Mesa version: 3.0.0b2\n" ] } ], - "source": [ - "import mesa\n", - "print(f\"Mesa version: {mesa.__version__}\")\n", - "\n", - "from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib\n", - "# Import the local MoneyModel.py\n", - "from MoneyModel import MoneyModel\n" - ] + "execution_count": 1 }, { "cell_type": "code", - "execution_count": null, "metadata": { - "tags": [] + "tags": [], + "ExecuteTime": { + "end_time": "2024-10-29T19:38:46.079286Z", + "start_time": "2024-10-29T19:38:46.076984Z" + } }, - "outputs": [], "source": [ "def agent_portrayal(agent):\n", " return {\n", " \"color\": \"tab:blue\",\n", " \"size\": 50,\n", " }" - ] + ], + "outputs": [], + "execution_count": 2 }, { "cell_type": "markdown", @@ -96,9 +104,12 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:46.081662Z", + "start_time": "2024-10-29T19:38:46.079838Z" + } + }, "source": [ "model_params = {\n", " \"n\": {\n", @@ -112,7 +123,9 @@ " \"width\": 10,\n", " \"height\": 10,\n", "}" - ] + ], + "outputs": [], + "execution_count": 3 }, { "cell_type": "markdown", @@ -130,16 +143,18 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": { - "tags": [] + "tags": [], + "ExecuteTime": { + "end_time": "2024-10-29T19:38:46.864371Z", + "start_time": "2024-10-29T19:38:46.082810Z" + } }, - "outputs": [], "source": [ "# Create initial model instance\n", "model1 = MoneyModel(50, 10, 10)\n", "\n", - "SpaceGraph = make_space_matplotlib(agent_portrayal)\n", + "SpaceGraph = make_space_component(agent_portrayal)\n", "GiniPlot = make_plot_measure(\"Gini\")\n", "\n", "page = SolaraViz(\n", @@ -150,7 +165,27 @@ ")\n", "# This is required to render the visualization in the Jupyter notebook\n", "page" - ] + ], + "outputs": [ + { + "data": { + "text/plain": [ + "Cannot show ipywidgets in text" + ], + "text/html": [ + "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter +)." + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "c9f2ef2b5a24483c92fa129213414a2c" + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 4 }, { "cell_type": "markdown", @@ -169,23 +204,39 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:46.867576Z", + "start_time": "2024-10-29T19:38:46.865205Z" + } + }, "source": [ "import mesa\n", "print(f\"Mesa version: {mesa.__version__}\")\n", "\n", - "from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib\n", + "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n", "# Import the local MoneyModel.py\n", "from MoneyModel import MoneyModel\n" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mesa version: 3.0.0b2\n" + ] + } + ], + "execution_count": 5 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:46.870617Z", + "start_time": "2024-10-29T19:38:46.868336Z" + } + }, "source": [ "def agent_portrayal(agent):\n", " size = 10\n", @@ -207,18 +258,23 @@ " \"width\": 10,\n", " \"height\": 10,\n", "}" - ] + ], + "outputs": [], + "execution_count": 6 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:47.881911Z", + "start_time": "2024-10-29T19:38:46.871328Z" + } + }, "source": [ "# Create initial model instance\n", "model1 = MoneyModel(50, 10, 10)\n", "\n", - "SpaceGraph = make_space_matplotlib(agent_portrayal)\n", + "SpaceGraph = make_space_component(agent_portrayal)\n", "GiniPlot = make_plot_measure(\"Gini\")\n", "\n", "page = SolaraViz(\n", @@ -229,7 +285,27 @@ ")\n", "# This is required to render the visualization in the Jupyter notebook\n", "page" - ] + ], + "outputs": [ + { + "data": { + "text/plain": [ + "Cannot show ipywidgets in text" + ], + "text/html": [ + "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter +)." + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "da8518ec9ce74c068288bec0c8d3793e" + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 7 }, { "cell_type": "markdown", @@ -250,9 +326,12 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:47.885386Z", + "start_time": "2024-10-29T19:38:47.882808Z" + } + }, "source": [ "import mesa\n", "print(f\"Mesa version: {mesa.__version__}\")\n", @@ -260,16 +339,29 @@ "from matplotlib.figure import Figure\n", "\n", "from mesa.visualization.utils import update_counter\n", - "from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib\n", + "from mesa.visualization import SolaraViz, make_plot_measure, make_space_component\n", "# Import the local MoneyModel.py\n", "from MoneyModel import MoneyModel\n" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mesa version: 3.0.0b2\n" + ] + } + ], + "execution_count": 8 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:47.888491Z", + "start_time": "2024-10-29T19:38:47.886217Z" + } + }, "source": [ "def agent_portrayal(agent):\n", " size = 10\n", @@ -291,7 +383,9 @@ " \"width\": 10,\n", " \"height\": 10,\n", "}" - ] + ], + "outputs": [], + "execution_count": 9 }, { "cell_type": "markdown", @@ -302,9 +396,12 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:47.893643Z", + "start_time": "2024-10-29T19:38:47.891084Z" + } + }, "source": [ "@solara.component\n", "def Histogram(model):\n", @@ -318,26 +415,36 @@ " # because plt.hist is not thread-safe.\n", " ax.hist(wealth_vals, bins=10)\n", " solara.FigureMatplotlib(fig)" - ] + ], + "outputs": [], + "execution_count": 10 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:47.896565Z", + "start_time": "2024-10-29T19:38:47.894387Z" + } + }, "source": [ "# Create initial model instance\n", "model1 = MoneyModel(50, 10, 10)\n", "\n", - "SpaceGraph = make_space_matplotlib(agent_portrayal)\n", + "SpaceGraph = make_space_component(agent_portrayal)\n", "GiniPlot = make_plot_measure(\"Gini\")" - ] + ], + "outputs": [], + "execution_count": 11 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:49.471838Z", + "start_time": "2024-10-29T19:38:47.897295Z" + } + }, "source": [ "page = SolaraViz(\n", " model1,\n", @@ -347,7 +454,27 @@ ")\n", "# This is required to render the visualization in the Jupyter notebook\n", "page" - ] + ], + "outputs": [ + { + "data": { + "text/plain": [ + "Cannot show ipywidgets in text" + ], + "text/html": [ + "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter +)." + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "bc71b89ee5684038a194eee4c36f4a4c" + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 12 }, { "cell_type": "markdown", @@ -358,12 +485,35 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-29T19:38:49.505725Z", + "start_time": "2024-10-29T19:38:49.472599Z" + } + }, "source": [ "Histogram(model1)" - ] + ], + "outputs": [ + { + "data": { + "text/plain": [ + "Cannot show ipywidgets in text" + ], + "text/html": [ + "Cannot show widget. You probably want to rerun the code cell above (Click in the code cell, and press Shift+Enter +)." + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "0491f167a1434a92b78535078bd082a8" + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 13 }, { "cell_type": "markdown", diff --git a/mesa/examples/advanced/epstein_civil_violence/app.py b/mesa/examples/advanced/epstein_civil_violence/app.py index 862ca6220d8..538ef186f57 100644 --- a/mesa/examples/advanced/epstein_civil_violence/app.py +++ b/mesa/examples/advanced/epstein_civil_violence/app.py @@ -8,7 +8,7 @@ Slider, SolaraViz, make_plot_measure, - make_space_matplotlib, + make_space_component, ) COP_COLOR = "#000000" @@ -47,7 +47,7 @@ def citizen_cop_portrayal(agent): "max_jail_term": Slider("Max Jail Term", 30, 0, 50, 1), } -space_component = make_space_matplotlib(citizen_cop_portrayal) +space_component = make_space_component(citizen_cop_portrayal) chart_component = make_plot_measure([state.name.lower() for state in CitizenState]) epstein_model = EpsteinCivilViolence() diff --git a/mesa/examples/advanced/pd_grid/app.py b/mesa/examples/advanced/pd_grid/app.py index c8ceec9fe16..6edf8140536 100644 --- a/mesa/examples/advanced/pd_grid/app.py +++ b/mesa/examples/advanced/pd_grid/app.py @@ -3,7 +3,7 @@ """ from mesa.examples.advanced.pd_grid.model import PdGrid -from mesa.visualization import SolaraViz, make_plot_measure, make_space_matplotlib +from mesa.visualization import SolaraViz, make_plot_measure, make_space_component from mesa.visualization.UserParam import Slider @@ -32,7 +32,7 @@ def pd_agent_portrayal(agent): # Create grid visualization component using Altair -grid_viz = make_space_matplotlib(agent_portrayal=pd_agent_portrayal) +grid_viz = make_space_component(agent_portrayal=pd_agent_portrayal) # Create plot for tracking cooperating agents over time plot_component = make_plot_measure("Cooperating_Agents") diff --git a/mesa/examples/advanced/sugarscape_g1mt/app.py b/mesa/examples/advanced/sugarscape_g1mt/app.py index 7c8cc2cfead..752998891bd 100644 --- a/mesa/examples/advanced/sugarscape_g1mt/app.py +++ b/mesa/examples/advanced/sugarscape_g1mt/app.py @@ -1,11 +1,3 @@ -import os.path -import sys - -sys.path.insert( - 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../")) -) - - import numpy as np import solara from matplotlib.figure import Figure diff --git a/mesa/examples/advanced/wolf_sheep/app.py b/mesa/examples/advanced/wolf_sheep/app.py index b5ac6e8bf47..a8c0a1e9c49 100644 --- a/mesa/examples/advanced/wolf_sheep/app.py +++ b/mesa/examples/advanced/wolf_sheep/app.py @@ -4,12 +4,9 @@ Slider, SolaraViz, make_plot_measure, - make_space_matplotlib, + make_space_component, ) -WOLF_COLOR = "#000000" -SHEEP_COLOR = "#648FFF" - def wolf_sheep_portrayal(agent): if agent is None: @@ -17,23 +14,23 @@ def wolf_sheep_portrayal(agent): portrayal = { "size": 25, - "shape": "s", # square marker } if isinstance(agent, Wolf): - portrayal["color"] = WOLF_COLOR - portrayal["Layer"] = 3 + portrayal["color"] = "tab:red" + portrayal["marker"] = "o" + portrayal["zorder"] = 2 elif isinstance(agent, Sheep): - portrayal["color"] = SHEEP_COLOR - portrayal["Layer"] = 2 + portrayal["color"] = "tab:cyan" + portrayal["marker"] = "o" + portrayal["zorder"] = 2 elif isinstance(agent, GrassPatch): if agent.fully_grown: - portrayal["color"] = "#00FF00" + portrayal["color"] = "tab:green" else: - portrayal["color"] = "#84e184" - # portrayal["shape"] = "rect" - # portrayal["Filled"] = "true" - portrayal["Layer"] = 1 + portrayal["color"] = "tab:brown" + portrayal["marker"] = "s" + portrayal["size"] = 75 return portrayal @@ -62,10 +59,20 @@ def wolf_sheep_portrayal(agent): } -space_component = make_space_matplotlib(wolf_sheep_portrayal) -lineplot_component = make_plot_measure(["Wolves", "Sheep", "Grass"]) +def post_process(ax): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + + +space_component = make_space_component( + wolf_sheep_portrayal, draw_grid=False, post_process=post_process +) +lineplot_component = make_plot_measure( + {"Wolves": "tab:orange", "Sheep": "tab:cyan", "Grass": "tab:green"} +) -model = WolfSheep() +model = WolfSheep(grass=True) page = SolaraViz( diff --git a/mesa/examples/basic/boid_flockers/app.py b/mesa/examples/basic/boid_flockers/app.py index 482d582b8ba..bcecb0a3ebd 100644 --- a/mesa/examples/basic/boid_flockers/app.py +++ b/mesa/examples/basic/boid_flockers/app.py @@ -1,5 +1,5 @@ from mesa.examples.basic.boid_flockers.model import BoidFlockers -from mesa.visualization import Slider, SolaraViz, make_space_matplotlib +from mesa.visualization import Slider, SolaraViz, make_space_component def boid_draw(agent): @@ -51,7 +51,7 @@ def boid_draw(agent): page = SolaraViz( model, - [make_space_matplotlib(agent_portrayal=boid_draw)], + [make_space_component(agent_portrayal=boid_draw)], model_params=model_params, name="Boid Flocking Model", ) diff --git a/mesa/examples/basic/boltzmann_wealth_model/app.py b/mesa/examples/basic/boltzmann_wealth_model/app.py index 7e3f41e64de..2ab6d06bf73 100644 --- a/mesa/examples/basic/boltzmann_wealth_model/app.py +++ b/mesa/examples/basic/boltzmann_wealth_model/app.py @@ -2,7 +2,7 @@ from mesa.visualization import ( SolaraViz, make_plot_measure, - make_space_matplotlib, + make_space_component, ) @@ -36,7 +36,7 @@ def agent_portrayal(agent): # Under the hood these are just classes that receive the model instance. # You can also author your own visualization elements, which can also be functions # that receive the model instance and return a valid solara component. -SpaceGraph = make_space_matplotlib(agent_portrayal) +SpaceGraph = make_space_component(agent_portrayal) GiniPlot = make_plot_measure("Gini") # Create the SolaraViz page. This will automatically create a server and display the diff --git a/mesa/examples/basic/conways_game_of_life/app.py b/mesa/examples/basic/conways_game_of_life/app.py index 2c9dace8635..7a45125a30a 100644 --- a/mesa/examples/basic/conways_game_of_life/app.py +++ b/mesa/examples/basic/conways_game_of_life/app.py @@ -1,12 +1,12 @@ from mesa.examples.basic.conways_game_of_life.model import ConwaysGameOfLife from mesa.visualization import ( SolaraViz, - make_space_matplotlib, + make_space_component, ) def agent_portrayal(agent): - return {"color": "white" if agent.state == 0 else "black"} + return {"c": "white" if agent.state == 0 else "black", "marker": "s"} model_params = { @@ -22,7 +22,7 @@ def agent_portrayal(agent): # Under the hood these are just classes that receive the model instance. # You can also author your own visualization elements, which can also be functions # that receive the model instance and return a valid solara component. -SpaceGraph = make_space_matplotlib(agent_portrayal) +SpaceGraph = make_space_component(agent_portrayal) # Create the SolaraViz page. This will automatically create a server and display the diff --git a/mesa/examples/basic/schelling/app.py b/mesa/examples/basic/schelling/app.py index 72ae6ddc1ec..53fab7ba0f0 100644 --- a/mesa/examples/basic/schelling/app.py +++ b/mesa/examples/basic/schelling/app.py @@ -5,7 +5,7 @@ Slider, SolaraViz, make_plot_measure, - make_space_matplotlib, + make_space_component, ) @@ -33,7 +33,7 @@ def agent_portrayal(agent): page = SolaraViz( model1, components=[ - make_space_matplotlib(agent_portrayal), + make_space_component(agent_portrayal), make_plot_measure("happy"), get_happy_agents, ], diff --git a/mesa/examples/basic/virus_on_network/app.py b/mesa/examples/basic/virus_on_network/app.py index 0183d256790..7cf54f308d5 100644 --- a/mesa/examples/basic/virus_on_network/app.py +++ b/mesa/examples/basic/virus_on_network/app.py @@ -9,7 +9,7 @@ VirusOnNetwork, number_infected, ) -from mesa.visualization import Slider, SolaraViz, make_space_matplotlib +from mesa.visualization import Slider, SolaraViz, make_space_component def agent_portrayal(graph): @@ -119,7 +119,7 @@ def make_plot(model): ), } -SpacePlot = make_space_matplotlib(agent_portrayal) +SpacePlot = make_space_component(agent_portrayal) model1 = VirusOnNetwork() diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index d6e50c37e36..0e1875c751c 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -1,7 +1,7 @@ """Solara based visualization for Mesa models.""" from .components.altair import make_space_altair -from .components.matplotlib import make_plot_measure, make_space_matplotlib +from .components.matplotlib import make_plot_measure, make_space_component from .solara_viz import JupyterViz, SolaraViz from .UserParam import Slider @@ -10,6 +10,6 @@ "SolaraViz", "Slider", "make_space_altair", - "make_space_matplotlib", + "make_space_component", "make_plot_measure", ] diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index bea633d6c8b..7e9982a7387 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,27 +1,63 @@ """Matplotlib based solara components for visualization MESA spaces and plots.""" +import itertools +import math import warnings +from collections.abc import Callable +from typing import Any import matplotlib.pyplot as plt import networkx as nx import numpy as np import solara +from matplotlib.axes import Axes from matplotlib.cm import ScalarMappable +from matplotlib.collections import PatchCollection from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba from matplotlib.figure import Figure +from matplotlib.patches import RegularPolygon import mesa -from mesa.experimental.cell_space import Grid, VoronoiGrid -from mesa.space import PropertyLayer +from mesa.experimental.cell_space import ( + OrthogonalMooreGrid, + OrthogonalVonNeumannGrid, + VoronoiGrid, +) +from mesa.space import ( + ContinuousSpace, + HexMultiGrid, + HexSingleGrid, + MultiGrid, + NetworkGrid, + PropertyLayer, + SingleGrid, +) from mesa.visualization.utils import update_counter +# For typing +OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid +HexGrid = HexSingleGrid | HexMultiGrid | mesa.experimental.cell_space.HexGrid +Network = NetworkGrid | mesa.experimental.cell_space.Network -def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None): + +def make_space_component( + agent_portrayal: Callable | None = None, + propertylayer_portrayal: dict | None = None, + post_process: Callable | None = None, + **space_drawing_kwargs, +): """Create a Matplotlib-based space visualization component. Args: - agent_portrayal (function): Function to portray agents - propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications + agent_portrayal: Function to portray agents. + propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications + post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks) + space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See + the functions for drawing the various spaces for further details. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + Returns: function: A function that creates a SpaceMatplotlib component @@ -29,10 +65,16 @@ def make_space_matplotlib(agent_portrayal=None, propertylayer_portrayal=None): if agent_portrayal is None: def agent_portrayal(a): - return {"id": a.unique_id} + return {} def MakeSpaceMatplotlib(model): - return SpaceMatplotlib(model, agent_portrayal, propertylayer_portrayal) + return SpaceMatplotlib( + model, + agent_portrayal, + propertylayer_portrayal, + post_process=post_process, + **space_drawing_kwargs, + ) return MakeSpaceMatplotlib @@ -43,48 +85,157 @@ def SpaceMatplotlib( agent_portrayal, propertylayer_portrayal, dependencies: list[any] | None = None, + post_process: Callable | None = None, + **space_drawing_kwargs, ): """Create a Matplotlib-based space visualization component.""" update_counter.get() - space_fig = Figure() - space_ax = space_fig.subplots() + space = getattr(model, "grid", None) if space is None: space = getattr(model, "space", None) + fig = Figure() + ax = fig.add_subplot() + + draw_space( + space, + agent_portrayal, + propertylayer_portrayal=propertylayer_portrayal, + ax=ax, + post_process=post_process, + **space_drawing_kwargs, + ) + + solara.FigureMatplotlib( + fig, format="png", bbox_inches="tight", dependencies=dependencies + ) + + +def collect_agent_data( + space: OrthogonalGrid | HexGrid | Network | ContinuousSpace | VoronoiGrid, + agent_portrayal: Callable, + color="tab:blue", + size=25, + marker="o", + zorder: int = 1, +): + """Collect the plotting data for all agents in the space. + + Args: + space: The space containing the Agents. + agent_portrayal: A callable that is called with the agent and returns a dict + color: default color + size: default size + marker: default marker + zorder: default zorder + + agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order), + and marker (marker style) + + """ + arguments = {"s": [], "c": [], "marker": [], "zorder": [], "loc": []} + + for agent in space.agents: + portray = agent_portrayal(agent) + loc = agent.pos + if loc is None: + loc = agent.cell.coordinate + + arguments["loc"].append(loc) + arguments["s"].append(portray.pop("size", size)) + arguments["c"].append(portray.pop("color", color)) + arguments["marker"].append(portray.pop("marker", marker)) + arguments["zorder"].append(portray.pop("zorder", zorder)) + + if len(portray) > 0: + ignored_fields = list(portray.keys()) + msg = ", ".join(ignored_fields) + warnings.warn( + f"the following fields are not used in agent portrayal and thus ignored: {msg}.", + stacklevel=2, + ) + + return {k: np.asarray(v) for k, v in arguments.items()} + + +def draw_space( + space, + agent_portrayal: Callable, + propertylayer_portrayal: dict | None = None, + ax: Axes | None = None, + post_process: Callable | None = None, + **space_drawing_kwargs, +): + """Draw a Matplotlib-based visualization of the space. + + Args: + space: the space of the mesa model + agent_portrayal: A callable that returns a dict specifying how to show the agent + propertylayer_portrayal: a dict specifying how to show propertylayer(s) + ax: the axes upon which to draw the plot + post_process: a callable called with the Axes instance + postprocess: a user-specified callable to do post-processing called with the Axes instance. This callable + can be used for any further fine-tuning of the plot (e.g., changing ticks, etc.) + space_drawing_kwargs: any additional keyword arguments to be passed on to the underlying function for drawing the space. + + Returns: + Returns the Axes object with the plot drawn onto it. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + """ + if ax is None: + fig, ax = plt.subplots() + # https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching match space: - case mesa.space._Grid(): - _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model) + case mesa.space._Grid() | OrthogonalMooreGrid() | OrthogonalVonNeumannGrid(): + draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs) + case HexSingleGrid() | HexMultiGrid() | mesa.experimental.cell_space.HexGrid(): + draw_hex_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs) + case mesa.space.NetworkGrid() | mesa.experimental.cell_space.Network(): + draw_network(space, agent_portrayal, ax=ax, **space_drawing_kwargs) case mesa.space.ContinuousSpace(): - _draw_continuous_space(space, space_ax, agent_portrayal, model) - case mesa.space.NetworkGrid(): - _draw_network_grid(space, space_ax, agent_portrayal) + draw_continuous_space(space, agent_portrayal, ax=ax) case VoronoiGrid(): - _draw_voronoi(space, space_ax, agent_portrayal) - case Grid(): # matches OrthogonalMooreGrid, OrthogonalVonNeumannGrid, and Hexgrid - # fixme add a separate draw method for hexgrids in the future - _draw_discrete_space_grid(space, space_ax, agent_portrayal) - case None: - if propertylayer_portrayal: - draw_property_layers(space_ax, space, propertylayer_portrayal, model) + draw_voroinoi_grid(space, agent_portrayal, ax=ax) - solara.FigureMatplotlib( - space_fig, format="png", bbox_inches="tight", dependencies=dependencies - ) + if propertylayer_portrayal: + draw_property_layers(space, propertylayer_portrayal, ax=ax) + + if post_process is not None: + post_process(ax=ax) + return ax -def draw_property_layers(ax, space, propertylayer_portrayal, model): + +def draw_property_layers( + space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes +): """Draw PropertyLayers on the given axes. Args: - ax (matplotlib.axes.Axes): The axes to draw on. space (mesa.space._Grid): The space containing the PropertyLayers. - propertylayer_portrayal (dict): Dictionary of PropertyLayer portrayal specifications. - model (mesa.Model): The model instance. + propertylayer_portrayal (dict): the key is the name of the layer, the value is a dict with + fields specifying how the layer is to be portrayed + ax (matplotlib.axes.Axes): The axes to draw on. + + Notes: + valid fields in in the inner dict of propertylayer_portrayal are "alpha", "vmin", "vmax", "color" or "colormap", and "colorbar" + so you can do `{"some_layer":{"colormap":'viridis', 'alpha':.25, "colorbar":False}}` + """ + try: + # old style spaces + property_layers = space.properties + except AttributeError: + # new style spaces + property_layers = space.property_layers + for layer_name, portrayal in propertylayer_portrayal.items(): - layer = getattr(model, layer_name, None) + layer = property_layers.get(layer_name, None) if not isinstance(layer, PropertyLayer): continue @@ -116,7 +267,6 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model): ) im = ax.imshow( rgba_data.transpose(1, 0, 2), - extent=(0, width, 0, height), origin="lower", ) if colorbar: @@ -135,7 +285,6 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model): alpha=alpha, vmin=vmin, vmax=vmax, - extent=(0, width, 0, height), origin="lower", ) if colorbar: @@ -146,131 +295,272 @@ def draw_property_layers(ax, space, propertylayer_portrayal, model): ) -def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model): - if propertylayer_portrayal: - draw_property_layers(space_ax, space, propertylayer_portrayal, model) +def draw_orthogonal_grid( + space: OrthogonalGrid, + agent_portrayal: Callable, + ax: Axes | None = None, + draw_grid: bool = True, +): + """Visualize a orthogonal grid. + + Args: + space: the space to visualize + agent_portrayal: a callable that is called with the agent and returns a dict + ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots + draw_grid: whether to draw the grid - agent_data = _get_agent_data(space, agent_portrayal) + Returns: + Returns the Axes object with the plot drawn onto it. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. - space_ax.set_xlim(0, space.width) - space_ax.set_ylim(0, space.height) - _split_and_scatter(agent_data, space_ax) + """ + if ax is None: + fig, ax = plt.subplots() - # Draw grid lines - for x in range(space.width + 1): - space_ax.axvline(x, color="gray", linestyle=":") - for y in range(space.height + 1): - space_ax.axhline(y, color="gray", linestyle=":") + # gather agent data + s_default = (180 / max(space.width, space.height)) ** 2 + arguments = collect_agent_data(space, agent_portrayal, size=s_default) + # plot the agents + _scatter(ax, arguments) -def _get_agent_data(space, agent_portrayal): - """Helper function to get agent data for visualization.""" - x, y, s, c, m = [], [], [], [], [] - for agents, pos in space.coord_iter(): - if not agents: - continue - if not isinstance(agents, list): - agents = [agents] # noqa PLW2901 - for agent in agents: - data = agent_portrayal(agent) - x.append(pos[0] + 0.5) # Center the agent in the cell - y.append(pos[1] + 0.5) # Center the agent in the cell - default_size = (180 / max(space.width, space.height)) ** 2 - s.append(data.get("size", default_size)) - c.append(data.get("color", "tab:blue")) - m.append(data.get("shape", "o")) - return {"x": x, "y": y, "s": s, "c": c, "m": m} - - -def _split_and_scatter(portray_data, space_ax): - """Helper function to split and scatter agent data.""" - for marker in set(portray_data["m"]): - mask = [m == marker for m in portray_data["m"]] - space_ax.scatter( - [x for x, show in zip(portray_data["x"], mask) if show], - [y for y, show in zip(portray_data["y"], mask) if show], - s=[s for s, show in zip(portray_data["s"], mask) if show], - c=[c for c, show in zip(portray_data["c"], mask) if show], - marker=marker, + # further styling + ax.set_xlim(-0.5, space.width - 0.5) + ax.set_ylim(-0.5, space.height - 0.5) + + if draw_grid: + # Draw grid lines + for x in np.arange(-0.5, space.width - 0.5, 1): + ax.axvline(x, color="gray", linestyle=":") + for y in np.arange(-0.5, space.height - 0.5, 1): + ax.axhline(y, color="gray", linestyle=":") + + return ax + + +def draw_hex_grid( + space: HexGrid, + agent_portrayal: Callable, + ax: Axes | None = None, + draw_grid: bool = True, +): + """Visualize a hex grid. + + Args: + space: the space to visualize + agent_portrayal: a callable that is called with the agent and returns a dict + ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots + draw_grid: whether to draw the grid + + Returns: + Returns the Axes object with the plot drawn onto it. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + """ + if ax is None: + fig, ax = plt.subplots() + + # gather data + s_default = (180 / max(space.width, space.height)) ** 2 + arguments = collect_agent_data(space, agent_portrayal, size=s_default) + + # for hexgrids we have to go from logical coordinates to visual coordinates + # this is a bit messy. + + # give all even rows an offset in the x direction + # give all rows an offset in the y direction + + # numbers here are based on a distance of 1 between centers of hexes + offset = math.sqrt(0.75) + + loc = arguments["loc"].astype(float) + + logical = np.mod(loc[:, 1], 2) == 0 + loc[:, 0][logical] += 0.5 + loc[:, 1] *= offset + arguments["loc"] = loc + + # plot the agents + _scatter(ax, arguments) + + # further styling and adding of grid + ax.set_xlim(-1, space.width + 0.5) + ax.set_ylim(-offset, space.height * offset) + + def setup_hexmesh( + width, + height, + ): + """Helper function for creating the hexmaesh.""" + # fixme: this should be done once, rather than in each update + # fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset) + + patches = [] + for x, y in itertools.product(range(width), range(height)): + if y % 2 == 0: + x += 0.5 # noqa: PLW2901 + y *= offset # noqa: PLW2901 + hex = RegularPolygon( + (x, y), + numVertices=6, + radius=math.sqrt(1 / 3), + orientation=np.radians(120), + ) + patches.append(hex) + mesh = PatchCollection( + patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1 + ) + return mesh + + if draw_grid: + # add grid + ax.add_collection( + setup_hexmesh( + space.width, + space.height, + ) ) + return ax -def _draw_network_grid(space, space_ax, agent_portrayal): +def draw_network( + space: Network, + agent_portrayal: Callable, + ax: Axes | None = None, + draw_grid: bool = True, + layout_alg=nx.spring_layout, + layout_kwargs=None, +): + """Visualize a network space. + + Args: + space: the space to visualize + agent_portrayal: a callable that is called with the agent and returns a dict + ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots + draw_grid: whether to draw the grid + layout_alg: a networkx layout algorithm or other callable with the same behavior + layout_kwargs: a dictionary of keyword arguments for the layout algorithm + + Returns: + Returns the Axes object with the plot drawn onto it. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + """ + if ax is None: + fig, ax = plt.subplots() + if layout_kwargs is None: + layout_kwargs = {"seed": 0} + + # gather locations for nodes in network graph = space.G - pos = nx.spring_layout(graph, seed=0) - nx.draw( - graph, - ax=space_ax, - pos=pos, - **agent_portrayal(graph), - ) + pos = layout_alg(graph, **layout_kwargs) + x, y = list(zip(*pos.values())) + xmin, xmax = min(x), max(x) + ymin, ymax = min(y), max(y) + width = xmax - xmin + height = ymax - ymin + x_padding = width / 20 + y_padding = height / 20 -def _draw_continuous_space(space, space_ax, agent_portrayal, model): - def portray(space): - x = [] - y = [] - s = [] # size - c = [] # color - m = [] # shape - for agent in space._agent_to_index: - data = agent_portrayal(agent) - _x, _y = agent.pos - x.append(_x) - y.append(_y) - - # This is matplotlib's default marker size - default_size = 20 - size = data.get("size", default_size) - s.append(size) - color = data.get("color", "tab:blue") - c.append(color) - mark = data.get("shape", "o") - m.append(mark) - return {"x": x, "y": y, "s": s, "c": c, "m": m} - - # Determine border style based on space.torus - border_style = "solid" if not space.torus else (0, (5, 10)) + # gather agent data + s_default = (180 / max(width, height)) ** 2 + arguments = collect_agent_data(space, agent_portrayal, size=s_default) - # Set the border of the plot - for spine in space_ax.spines.values(): - spine.set_linewidth(1.5) - spine.set_color("black") - spine.set_linestyle(border_style) + # this assumes that nodes are identified by an integer + # which is true for default nx graphs but might user changeable + pos = np.asarray(list(pos.values())) + arguments["loc"] = pos[arguments["loc"]] + + # plot the agents + _scatter(ax, arguments) + + # further styling + ax.set_axis_off() + ax.set_xlim(xmin=xmin - x_padding, xmax=xmax + x_padding) + ax.set_ylim(ymin=ymin - y_padding, ymax=ymax + y_padding) + + if draw_grid: + # fixme we need to draw the empty nodes as well + edge_collection = nx.draw_networkx_edges( + graph, pos, ax=ax, alpha=0.5, style="--" + ) + edge_collection.set_zorder(0) + + return ax + + +def draw_continuous_space( + space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None +): + """Visualize a continuous space. + + Args: + space: the space to visualize + agent_portrayal: a callable that is called with the agent and returns a dict + ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots + + Returns: + Returns the Axes object with the plot drawn onto it. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + """ + if ax is None: + fig, ax = plt.subplots() + # space related setup width = space.x_max - space.x_min x_padding = width / 20 height = space.y_max - space.y_min y_padding = height / 20 - space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding) - space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) - - # Portray and scatter the agents in the space - _split_and_scatter(portray(space), space_ax) - - -def _draw_voronoi(space, space_ax, agent_portrayal): - def portray(g): - x = [] - y = [] - s = [] # size - c = [] # color - - for cell in g.all_cells: - for agent in cell.agents: - data = agent_portrayal(agent) - x.append(cell.coordinate[0]) - y.append(cell.coordinate[1]) - if "size" in data: - s.append(data["size"]) - if "color" in data: - c.append(data["color"]) - out = {"x": x, "y": y} - out["s"] = s - if len(c) > 0: - out["c"] = c - - return out + + # gather agent data + s_default = (180 / max(width, height)) ** 2 + arguments = collect_agent_data(space, agent_portrayal, size=s_default) + + # plot the agents + _scatter(ax, arguments) + + # further visual styling + border_style = "solid" if not space.torus else (0, (5, 10)) + for spine in ax.spines.values(): + spine.set_linewidth(1.5) + spine.set_color("black") + spine.set_linestyle(border_style) + + ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding) + ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) + + return ax + + +def draw_voroinoi_grid( + space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None +): + """Visualize a voronoi grid. + + Args: + space: the space to visualize + agent_portrayal: a callable that is called with the agent and returns a dict + ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots + + Returns: + Returns the Axes object with the plot drawn onto it. + + ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + + """ + if ax is None: + fig, ax = plt.subplots() x_list = [i[0] for i in space.centroids_coordinates] y_list = [i[1] for i in space.centroids_coordinates] @@ -283,56 +573,49 @@ def portray(g): x_padding = width / 20 height = y_max - y_min y_padding = height / 20 - space_ax.set_xlim(x_min - x_padding, x_max + x_padding) - space_ax.set_ylim(y_min - y_padding, y_max + y_padding) - space_ax.scatter(**portray(space)) + + s_default = (180 / max(width, height)) ** 2 + arguments = collect_agent_data(space, agent_portrayal, size=s_default) + + ax.set_xlim(x_min - x_padding, x_max + x_padding) + ax.set_ylim(y_min - y_padding, y_max + y_padding) + + _scatter(ax, arguments) for cell in space.all_cells: polygon = cell.properties["polygon"] - space_ax.fill( + ax.fill( *zip(*polygon), alpha=min(1, cell.properties[space.cell_coloring_property]), c="red", + zorder=0, ) # Plot filled polygon - space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black - - -def _draw_discrete_space_grid(space: Grid, space_ax, agent_portrayal): - if space._ndims != 2: - raise ValueError("Space must be 2D") - - def portray(g): - x = [] - y = [] - s = [] # size - c = [] # color - - for cell in g.all_cells: - for agent in cell.agents: - data = agent_portrayal(agent) - x.append(cell.coordinate[0]) - y.append(cell.coordinate[1]) - if "size" in data: - s.append(data["size"]) - if "color" in data: - c.append(data["color"]) - out = {"x": x, "y": y} - out["s"] = s - if len(c) > 0: - out["c"] = c - - return out - - space_ax.set_xlim(0, space.width) - space_ax.set_ylim(0, space.height) - - # Draw grid lines - for x in range(space.width + 1): - space_ax.axvline(x, color="gray", linestyle=":") - for y in range(space.height + 1): - space_ax.axhline(y, color="gray", linestyle=":") - - space_ax.scatter(**portray(space)) + ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black + + return ax + + +def _scatter(ax: Axes, arguments): + """Helper function for plotting the agents.""" + loc = arguments.pop("loc") + + x = loc[:, 0] + y = loc[:, 1] + marker = arguments.pop("marker") + zorder = arguments.pop("zorder") + + for mark in np.unique(marker): + mark_mask = marker == mark + for z_order in np.unique(zorder): + zorder_mask = z_order == zorder + logical = mark_mask & zorder_mask + ax.scatter( + x[logical], + y[logical], + marker=mark, + zorder=z_order, + **{k: v[logical] for k, v in arguments.items()}, + ) def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): diff --git a/tests/test_components_matplotlib.py b/tests/test_components_matplotlib.py new file mode 100644 index 00000000000..c85dd1ce292 --- /dev/null +++ b/tests/test_components_matplotlib.py @@ -0,0 +1,158 @@ +"""tests for matplotlib components.""" + +import matplotlib.pyplot as plt + +from mesa import Agent, Model +from mesa.experimental.cell_space import ( + CellAgent, + HexGrid, + Network, + OrthogonalMooreGrid, + VoronoiGrid, +) +from mesa.space import ( + ContinuousSpace, + HexSingleGrid, + NetworkGrid, + PropertyLayer, + SingleGrid, +) +from mesa.visualization.components.matplotlib import ( + draw_continuous_space, + draw_hex_grid, + draw_network, + draw_orthogonal_grid, + draw_property_layers, + draw_voroinoi_grid, +) + + +def agent_portrayal(agent): + """Simple portrayal of an agent. + + Args: + agent (Agent): The agent to portray + + """ + return { + "s": 10, + "c": "tab:blue", + "marker": "s" if (agent.unique_id % 2) == 0 else "o", + } + + +def test_draw_hex_grid(): + """Test drawing hexgrids.""" + model = Model(seed=42) + grid = HexSingleGrid(10, 10, torus=True) + for _ in range(10): + agent = Agent(model) + grid.move_to_empty(agent) + + fig, ax = plt.subplots() + draw_hex_grid(grid, agent_portrayal, ax) + + model = Model(seed=42) + grid = HexGrid((10, 10), torus=True, random=model.random, capacity=1) + for _ in range(10): + agent = CellAgent(model) + agent.cell = grid.select_random_empty_cell() + + fig, ax = plt.subplots() + draw_hex_grid(grid, agent_portrayal, ax) + + +def test_draw_voroinoi_grid(): + """Test drawing voroinoi grids.""" + model = Model(seed=42) + + coordinates = model.rng.random((100, 2)) * 10 + + grid = VoronoiGrid(coordinates.tolist(), random=model.random, capacity=1) + for _ in range(10): + agent = CellAgent(model) + agent.cell = grid.select_random_empty_cell() + + fig, ax = plt.subplots() + draw_voroinoi_grid(grid, agent_portrayal, ax) + + +def test_draw_orthogonal_grid(): + """Test drawing orthogonal grids.""" + model = Model(seed=42) + grid = SingleGrid(10, 10, torus=True) + for _ in range(10): + agent = Agent(model) + grid.move_to_empty(agent) + + fig, ax = plt.subplots() + draw_orthogonal_grid(grid, agent_portrayal, ax) + + model = Model(seed=42) + grid = OrthogonalMooreGrid((10, 10), torus=True, random=model.random, capacity=1) + for _ in range(10): + agent = CellAgent(model) + agent.cell = grid.select_random_empty_cell() + + fig, ax = plt.subplots() + draw_orthogonal_grid(grid, agent_portrayal, ax) + + +def test_draw_continuous_space(): + """Test drawing continuous space.""" + model = Model(seed=42) + space = ContinuousSpace(10, 10, torus=True) + for _ in range(10): + x = model.random.random() * 10 + y = model.random.random() * 10 + agent = Agent(model) + space.place_agent(agent, (x, y)) + + fig, ax = plt.subplots() + draw_continuous_space(space, agent_portrayal, ax) + + +def test_draw_network(): + """Test drawing network.""" + import networkx as nx + + n = 10 + m = 20 + seed = 42 + graph = nx.gnm_random_graph(n, m, seed=seed) + + model = Model(seed=42) + grid = NetworkGrid(graph) + for _ in range(10): + agent = Agent(model) + pos = agent.random.randint(0, len(graph.nodes) - 1) + grid.place_agent(agent, pos) + + fig, ax = plt.subplots() + draw_network(grid, agent_portrayal, ax) + + model = Model(seed=42) + grid = Network(graph, random=model.random, capacity=1) + for _ in range(10): + agent = CellAgent(model) + agent.cell = grid.select_random_empty_cell() + + fig, ax = plt.subplots() + draw_network(grid, agent_portrayal, ax) + + +def test_draw_property_layers(): + """Test drawing property layers.""" + model = Model(seed=42) + grid = SingleGrid(10, 10, torus=True) + grid.add_property_layer(PropertyLayer("test", grid.width, grid.height, 0)) + + fig, ax = plt.subplots() + draw_property_layers(grid, {"test": {"colormap": "viridis", "colorbar": True}}, ax) + + model = Model(seed=42) + grid = OrthogonalMooreGrid((10, 10), torus=True, random=model.random, capacity=1) + grid.add_property_layer(PropertyLayer("test", grid.width, grid.height, 0)) + + fig, ax = plt.subplots() + draw_property_layers(grid, {"test": {"colormap": "viridis", "colorbar": True}}, ax) diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index a0d2b449399..af6badd0bc2 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -8,7 +8,7 @@ import mesa import mesa.visualization.components.altair import mesa.visualization.components.matplotlib -from mesa.visualization.components.matplotlib import make_space_matplotlib +from mesa.visualization.components.matplotlib import make_space_component from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs @@ -97,16 +97,16 @@ def test_call_space_drawer(mocker): # noqa: D103 mocker.patch.object(mesa.Model, "__init__", return_value=None) agent_portrayal = { - "Shape": "circle", + "marker": "circle", "color": "gray", } propertylayer_portrayal = None # initialize with space drawer unspecified (use default) # component must be rendered for code to run - solara.render(SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)])) + solara.render(SolaraViz(model, components=[make_space_component(agent_portrayal)])) # should call default method with class instance and agent portrayal mock_space_matplotlib.assert_called_with( - model, agent_portrayal, propertylayer_portrayal + model, agent_portrayal, propertylayer_portrayal, post_process=None ) # specify no space should be drawn @@ -132,7 +132,7 @@ def drawer(model): centroids_coordinates=[(0, 1), (0, 0), (1, 0)], ) solara.render( - SolaraViz(voronoi_model, components=[make_space_matplotlib(agent_portrayal)]) + SolaraViz(voronoi_model, components=[make_space_component(agent_portrayal)]) )