From 566b72facee68c77b77de1b31d967b15a3f9d66b Mon Sep 17 00:00:00 2001 From: nissu99 Date: Sun, 26 Jan 2025 11:22:15 +0530 Subject: [PATCH 01/29] renamed make_space_altair --- mesa/visualization/components/altair_components.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index b610e46f0d0..a0df903d511 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -15,12 +15,16 @@ def make_space_altair(*args, **kwargs): # noqa: D103 warnings.warn( - "make_space_altair has been renamed to make_altair_space", + "make_space_altair has been renamed to make_altair_space_component", DeprecationWarning, stacklevel=2, ) - return make_altair_space(*args, **kwargs) + return make_altair_space_component(*args, **kwargs) + +def make_altair_space_component(*args, **kwargs): + + return make_altair_space(*args, **kwargs) def make_altair_space( agent_portrayal, propertylayer_portrayal, post_process, **space_drawing_kwargs From e1f947c32e0e6e154071156b621c2b48db94675d Mon Sep 17 00:00:00 2001 From: nissu99 Date: Sun, 26 Jan 2025 11:50:50 +0530 Subject: [PATCH 02/29] added make_altair_plot_component support --- mesa/visualization/__init__.py | 3 +- mesa/visualization/components/__init__.py | 4 +- .../components/altair_components.py | 102 +++++++++++++++++- 3 files changed, 105 insertions(+), 4 deletions(-) diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index 5e6a95e676b..10e2b1c8c38 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -11,7 +11,7 @@ ) from .components import make_plot_component, make_space_component -from .components.altair_components import make_space_altair +from .components.altair_components import make_space_altair,make_altair_plot_component from .solara_viz import JupyterViz, SolaraViz from .user_param import Slider @@ -22,5 +22,6 @@ "draw_space", "make_plot_component", "make_space_altair", + "make_altair_plot_component", "make_space_component", ] diff --git a/mesa/visualization/components/__init__.py b/mesa/visualization/components/__init__.py index 4b70fc2b97c..b17bd6a9dbb 100644 --- a/mesa/visualization/components/__init__.py +++ b/mesa/visualization/components/__init__.py @@ -4,7 +4,7 @@ from collections.abc import Callable -from .altair_components import SpaceAltair, make_altair_space +from .altair_components import SpaceAltair, make_altair_space,make_altair_plot_component from .matplotlib_components import ( SpaceMatplotlib, make_mpl_plot_component, @@ -76,7 +76,7 @@ def make_plot_component( if backend == "matplotlib": return make_mpl_plot_component(measure, post_process, **plot_drawing_kwargs) elif backend == "altair": - raise NotImplementedError("altair line plots are not yet implemented") + return make_altair_plot_component(measure, post_process, **plot_drawing_kwargs) else: raise ValueError( f"unknown backend {backend}, must be one of matplotlib, altair" diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index a0df903d511..f6cb1da7d59 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,10 +1,11 @@ """Altair based solara components for visualization mesa spaces.""" import contextlib +from typing import Callable import warnings - import solara + with contextlib.suppress(ImportError): import altair as alt @@ -55,6 +56,105 @@ def MakeSpaceAltair(model): return MakeSpaceAltair + +def make_altair_plot_component( + measure: str | dict[str, str] | list[str] | tuple[str], + post_process: Callable | None = None, + width: int = 500, + height: int = 300, +): + """Create an Altair plotting component for specified measures. + + Args: + measure: Measure(s) to plot. Can be: + - str: Single measure name + - dict: Mapping of measure names to colors + - list/tuple: Multiple measure names + post_process: Optional callable for chart post-processing + width: Chart width in pixels + height: Chart height in pixels + + Returns: + function: A function that creates a PlotAltair component + """ + def MakePlotAltair(model): + return PlotAltair( + model, + measure, + post_process=post_process, + width=width, + height=height + ) + return MakePlotAltair + + +@solara.component +def PlotAltair( + model, + measure: str | dict[str, str] | list[str] | tuple[str], + post_process: Callable[[alt.Chart], alt.Chart] | None = None, + width: int = 500, + height: int = 300, +) -> solara.FigureAltair: + """Create an Altair plot for model.""" + + update_counter.get() + df = model.datacollector.get_model_vars_dataframe().reset_index() + + if isinstance(measure, str): + # Single measure - no transformation needed + chart = alt.Chart(df).encode( + x='Step:Q', + y=alt.Y(f'{measure}:Q', title=measure), + tooltip=[alt.Tooltip('Step:Q'), alt.Tooltip(f'{measure}:Q')] + ).mark_line() + + elif isinstance(measure, (list, tuple)): + # Multiple measures - melt dataframe + value_vars = list(measure) + melted_df = df.melt('Step', value_vars=value_vars, + var_name='Measure', value_name='Value') + + chart = alt.Chart(melted_df).encode( + x='Step:Q', + y=alt.Y('Value:Q'), + color='Measure:N', + tooltip=['Step:Q', 'Value:Q', 'Measure:N'] + ).mark_line() + + elif isinstance(measure, dict): + # Dictionary with colors - melt dataframe + value_vars = list(measure.keys()) + melted_df = df.melt('Step', value_vars=value_vars, + var_name='Measure', value_name='Value') + + # Create color scale from measure dict + domain = list(measure.keys()) + range_ = list(measure.values()) + + chart = alt.Chart(melted_df).encode( + x='Step:Q', + y=alt.Y('Value:Q'), + color=alt.Color('Measure:N', scale=alt.Scale(domain=domain, range=range_)), + tooltip=['Step:Q', 'Value:Q', 'Measure:N'] + ).mark_line() + + else: + raise ValueError("Unsupported measure type") + + # Configure chart properties + chart = chart.properties( + width=width, + height=height + ).configure_axis( + grid=True + ) + + if post_process is not None: + chart = post_process(chart) + + return solara.FigureAltair(chart) + @solara.component def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): """Create an Altair-based space visualization component. From 9e3a62580c70aa52e16d76809e793743fd4513d8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 26 Jan 2025 06:35:37 +0000 Subject: [PATCH 03/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/__init__.py | 4 +- mesa/visualization/components/__init__.py | 6 +- .../components/altair_components.py | 92 ++++++++++--------- 3 files changed, 57 insertions(+), 45 deletions(-) diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index 10e2b1c8c38..c43ad2008f7 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -11,7 +11,7 @@ ) from .components import make_plot_component, make_space_component -from .components.altair_components import make_space_altair,make_altair_plot_component +from .components.altair_components import make_altair_plot_component, make_space_altair from .solara_viz import JupyterViz, SolaraViz from .user_param import Slider @@ -20,8 +20,8 @@ "Slider", "SolaraViz", "draw_space", + "make_altair_plot_component", "make_plot_component", "make_space_altair", - "make_altair_plot_component", "make_space_component", ] diff --git a/mesa/visualization/components/__init__.py b/mesa/visualization/components/__init__.py index b17bd6a9dbb..58ce45b0bc1 100644 --- a/mesa/visualization/components/__init__.py +++ b/mesa/visualization/components/__init__.py @@ -4,7 +4,11 @@ from collections.abc import Callable -from .altair_components import SpaceAltair, make_altair_space,make_altair_plot_component +from .altair_components import ( + SpaceAltair, + make_altair_plot_component, + make_altair_space, +) from .matplotlib_components import ( SpaceMatplotlib, make_mpl_plot_component, diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index f6cb1da7d59..59f28f0fa97 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,10 +1,10 @@ """Altair based solara components for visualization mesa spaces.""" import contextlib -from typing import Callable import warnings -import solara +from collections.abc import Callable +import solara with contextlib.suppress(ImportError): import altair as alt @@ -24,9 +24,9 @@ def make_space_altair(*args, **kwargs): # noqa: D103 def make_altair_space_component(*args, **kwargs): - return make_altair_space(*args, **kwargs) + def make_altair_space( agent_portrayal, propertylayer_portrayal, post_process, **space_drawing_kwargs ): @@ -56,7 +56,6 @@ def MakeSpaceAltair(model): return MakeSpaceAltair - def make_altair_plot_component( measure: str | dict[str, str] | list[str] | tuple[str], post_process: Callable | None = None, @@ -77,14 +76,12 @@ def make_altair_plot_component( Returns: function: A function that creates a PlotAltair component """ + def MakePlotAltair(model): return PlotAltair( - model, - measure, - post_process=post_process, - width=width, - height=height + model, measure, post_process=post_process, width=width, height=height ) + return MakePlotAltair @@ -97,64 +94,75 @@ def PlotAltair( height: int = 300, ) -> solara.FigureAltair: """Create an Altair plot for model.""" - update_counter.get() df = model.datacollector.get_model_vars_dataframe().reset_index() if isinstance(measure, str): # Single measure - no transformation needed - chart = alt.Chart(df).encode( - x='Step:Q', - y=alt.Y(f'{measure}:Q', title=measure), - tooltip=[alt.Tooltip('Step:Q'), alt.Tooltip(f'{measure}:Q')] - ).mark_line() + chart = ( + alt.Chart(df) + .encode( + x="Step:Q", + y=alt.Y(f"{measure}:Q", title=measure), + tooltip=[alt.Tooltip("Step:Q"), alt.Tooltip(f"{measure}:Q")], + ) + .mark_line() + ) elif isinstance(measure, (list, tuple)): # Multiple measures - melt dataframe value_vars = list(measure) - melted_df = df.melt('Step', value_vars=value_vars, - var_name='Measure', value_name='Value') - - chart = alt.Chart(melted_df).encode( - x='Step:Q', - y=alt.Y('Value:Q'), - color='Measure:N', - tooltip=['Step:Q', 'Value:Q', 'Measure:N'] - ).mark_line() + melted_df = df.melt( + "Step", value_vars=value_vars, var_name="Measure", value_name="Value" + ) + + chart = ( + alt.Chart(melted_df) + .encode( + x="Step:Q", + y=alt.Y("Value:Q"), + color="Measure:N", + tooltip=["Step:Q", "Value:Q", "Measure:N"], + ) + .mark_line() + ) elif isinstance(measure, dict): # Dictionary with colors - melt dataframe value_vars = list(measure.keys()) - melted_df = df.melt('Step', value_vars=value_vars, - var_name='Measure', value_name='Value') - + melted_df = df.melt( + "Step", value_vars=value_vars, var_name="Measure", value_name="Value" + ) + # Create color scale from measure dict domain = list(measure.keys()) range_ = list(measure.values()) - - chart = alt.Chart(melted_df).encode( - x='Step:Q', - y=alt.Y('Value:Q'), - color=alt.Color('Measure:N', scale=alt.Scale(domain=domain, range=range_)), - tooltip=['Step:Q', 'Value:Q', 'Measure:N'] - ).mark_line() - + + chart = ( + alt.Chart(melted_df) + .encode( + x="Step:Q", + y=alt.Y("Value:Q"), + color=alt.Color( + "Measure:N", scale=alt.Scale(domain=domain, range=range_) + ), + tooltip=["Step:Q", "Value:Q", "Measure:N"], + ) + .mark_line() + ) + else: raise ValueError("Unsupported measure type") # Configure chart properties - chart = chart.properties( - width=width, - height=height - ).configure_axis( - grid=True - ) + chart = chart.properties(width=width, height=height).configure_axis(grid=True) if post_process is not None: chart = post_process(chart) - + return solara.FigureAltair(chart) + @solara.component def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): """Create an Altair-based space visualization component. From a88d8166210703a8037b971d375e044d37e3415f Mon Sep 17 00:00:00 2001 From: Animesh Rawat <131552285+nissu99@users.noreply.github.com> Date: Mon, 27 Jan 2025 13:15:01 +0530 Subject: [PATCH 04/29] Update altair_components.py "alt defined" --- mesa/visualization/components/altair_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 59f28f0fa97..df2fcbef7de 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -5,7 +5,7 @@ from collections.abc import Callable import solara - +alt=None with contextlib.suppress(ImportError): import altair as alt From c579f81921f1a71f587114a9c97e3c467f7179a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 07:45:12 +0000 Subject: [PATCH 05/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/components/altair_components.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index df2fcbef7de..35d5e797c8a 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -5,7 +5,8 @@ from collections.abc import Callable import solara -alt=None + +alt = None with contextlib.suppress(ImportError): import altair as alt From 95a607b1f455884026977422c97ac9d3cab1fe33 Mon Sep 17 00:00:00 2001 From: Animesh Rawat <131552285+nissu99@users.noreply.github.com> Date: Mon, 27 Jan 2025 23:30:55 +0530 Subject: [PATCH 06/29] Update altair_components.py "removed a line" --- mesa/visualization/components/altair_components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 35d5e797c8a..59f28f0fa97 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -6,7 +6,6 @@ import solara -alt = None with contextlib.suppress(ImportError): import altair as alt From 8e1779b11422eec04d10e62366d38a0bc7bb0293 Mon Sep 17 00:00:00 2001 From: Animesh Rawat <131552285+nissu99@users.noreply.github.com> Date: Mon, 27 Jan 2025 23:35:01 +0530 Subject: [PATCH 07/29] Update pyproject.toml "updated toml" --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index be0dc4a3139..5e34996eb93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ network = [ viz = [ "matplotlib", "solara", + "altair" ] # Dev and CI stuff dev = [ From 2b69a621d05bc2690294d4d20a33ba054807d7b6 Mon Sep 17 00:00:00 2001 From: Animesh Rawat <131552285+nissu99@users.noreply.github.com> Date: Mon, 27 Jan 2025 23:43:00 +0530 Subject: [PATCH 08/29] Update altair_components.py "removed suppress" --- mesa/visualization/components/altair_components.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 59f28f0fa97..106c7c24565 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,13 +1,10 @@ """Altair based solara components for visualization mesa spaces.""" -import contextlib import warnings from collections.abc import Callable import solara -with contextlib.suppress(ImportError): - import altair as alt from mesa.experimental.cell_space import DiscreteSpace, Grid from mesa.space import ContinuousSpace, _Grid From b5613f1a4d0c8b569d02e4c2a1f560a8355b95a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 18:13:07 +0000 Subject: [PATCH 09/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/components/altair_components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 106c7c24565..b5ccbf8663a 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -5,7 +5,6 @@ import solara - from mesa.experimental.cell_space import DiscreteSpace, Grid from mesa.space import ContinuousSpace, _Grid from mesa.visualization.utils import update_counter From 0d1b8beb701c7062dfea986131763b54f013ff54 Mon Sep 17 00:00:00 2001 From: Animesh Rawat <131552285+nissu99@users.noreply.github.com> Date: Mon, 27 Jan 2025 23:49:34 +0530 Subject: [PATCH 10/29] Update altair_components.py "doc string " --- .../components/altair_components.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index b5ccbf8663a..019c23050e6 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,10 +1,13 @@ """Altair based solara components for visualization mesa spaces.""" - +import contextlib import warnings from collections.abc import Callable import solara +with contextlib.suppress(ImportError): + import altair as alt + from mesa.experimental.cell_space import DiscreteSpace, Grid from mesa.space import ContinuousSpace, _Grid from mesa.visualization.utils import update_counter @@ -20,6 +23,18 @@ def make_space_altair(*args, **kwargs): # noqa: D103 def make_altair_space_component(*args, **kwargs): + """Create an Altair-based space visualization component. + + Args: + *args: Positional arguments passed to make_altair_space + **kwargs: Keyword arguments passed to make_altair_space + + Returns: + function: A function that creates an Altair space visualization component + + See Also: + make_altair_space: The underlying implementation + """ return make_altair_space(*args, **kwargs) From ed8ca7104495c549f90d4965e9666128c81c4ceb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 18:19:41 +0000 Subject: [PATCH 11/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/components/altair_components.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 019c23050e6..0f3a61b5256 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,4 +1,5 @@ """Altair based solara components for visualization mesa spaces.""" + import contextlib import warnings from collections.abc import Callable @@ -7,7 +8,7 @@ with contextlib.suppress(ImportError): import altair as alt - + from mesa.experimental.cell_space import DiscreteSpace, Grid from mesa.space import ContinuousSpace, _Grid from mesa.visualization.utils import update_counter From da8fd536f8c72951aa3c55a06682cf1f1f60bd5b Mon Sep 17 00:00:00 2001 From: Animesh Rawat <131552285+nissu99@users.noreply.github.com> Date: Tue, 28 Jan 2025 00:08:46 +0530 Subject: [PATCH 12/29] Update altair_components.py "operator bug" --- mesa/visualization/components/altair_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 0f3a61b5256..c0a26994612 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -121,7 +121,7 @@ def PlotAltair( .mark_line() ) - elif isinstance(measure, (list, tuple)): + elif isinstance(measure, list | tuple): # Multiple measures - melt dataframe value_vars = list(measure) melted_df = df.melt( From 4a38eef2b404b381a78d81ada4a167b758436bf6 Mon Sep 17 00:00:00 2001 From: nissu99 Date: Fri, 31 Jan 2025 16:17:37 +0530 Subject: [PATCH 13/29] blank --- .../components/altair_components.py | 157 ++++++++++-------- 1 file changed, 86 insertions(+), 71 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index c0a26994612..e4636dba7c4 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,16 +1,19 @@ """Altair based solara components for visualization mesa spaces.""" import contextlib +import math import warnings from collections.abc import Callable import solara +import mesa + with contextlib.suppress(ImportError): import altair as alt -from mesa.experimental.cell_space import DiscreteSpace, Grid -from mesa.space import ContinuousSpace, _Grid +from mesa.experimental.cell_space import Grid, HexGrid +from mesa.space import ContinuousSpace, NetworkGrid, _Grid from mesa.visualization.utils import update_counter @@ -53,7 +56,6 @@ def make_altair_space( ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", "size", "marker", and "zorder". Other field are ignored and will result in a user warning. - Returns: function: A function that creates a SpaceMatplotlib component """ @@ -192,90 +194,103 @@ def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): solara.FigureAltair(chart) -def _get_agent_data_old__discrete_space(space, agent_portrayal): - """Format agent portrayal data for old-style discrete spaces. - - Args: - space: the mesa.space._Grid instance - agent_portrayal: the agent portrayal callable - - Returns: - list of dicts - - """ - all_agent_data = [] - for content, (x, y) in space.coord_iter(): - if not content: - continue - if not hasattr(content, "__iter__"): - # Is a single grid - content = [content] # noqa: PLW2901 - for agent in content: - # use all data from agent portrayal, and add x,y coordinates - agent_data = agent_portrayal(agent) - agent_data["x"] = x - agent_data["y"] = y - all_agent_data.append(agent_data) - return all_agent_data +def axial_to_pixel(q, r, size=1): + """Convert axial coordinates (q, r) to pixel coordinates for hexagonal grid.""" + x = size * math.sqrt(3) * (q + r / 2) + y = size * 1.5 * r + return x, y -def _get_agent_data_new_discrete_space(space: DiscreteSpace, agent_portrayal): - """Format agent portrayal data for new-style discrete spaces. +def get_agent_data(space, agent_portrayal): + """Generic method to extract agent data for visualization across all space types. Args: - space: the mesa.experiment.cell_space.Grid instance - agent_portrayal: the agent portrayal callable + space: Mesa space object + agent_portrayal: Function defining agent visualization properties Returns: - list of dicts - + List of agent data dictionaries with coordinates """ all_agent_data = [] - for cell in space.all_cells: - for agent in cell.agents: - agent_data = agent_portrayal(agent) - agent_data["x"] = cell.coordinate[0] - agent_data["y"] = cell.coordinate[1] - all_agent_data.append(agent_data) - return all_agent_data + # New DiscreteSpace + if isinstance(space, Grid): + for cell in space.all_cells: + for agent in cell.agents: + data = agent_portrayal(agent) + data.update({"x": cell.coordinate[0], "y": cell.coordinate[1]}) + all_agent_data.append(data) + + # Legacy Grid + elif isinstance(space, _Grid): + for content, (x, y) in space.coord_iter(): + if not content: + continue + agents = [content] if not hasattr(content, "__iter__") else content + for agent in agents: + data = agent_portrayal(agent) + data.update({"x": x, "y": y}) + all_agent_data.append(data) + + elif isinstance(space, HexGrid): + for content, (q, r) in space.coord_iter(): + if content: + for agent in content: + data = agent_portrayal(agent) + x, y = axial_to_pixel(q, r) + data.update({"x": x, "y": y}) + all_agent_data.append(data) + + elif isinstance(space, NetworkGrid): + for node in space.G.nodes(): + agents = space.G.nodes[node].get("agent", []) + if isinstance(agents, list): + agent_list = agents + else: + agent_list = [agents] if agents else [] + + for agent in agent_list: + if agent: + pos = space.G.nodes[node].get("pos", (0, 0)) + data = agent_portrayal(agent) + data.update({"x": pos[0], "y": pos[1]}) + all_agent_data.append(data) + + elif isinstance( + space, ContinuousSpace | mesa.experimental.continuous_space.ContinuousSpace + ): + for agent in space.agents: + data = agent_portrayal(agent) + data.update({"x": agent.pos[0], "y": agent.pos[1]}) + all_agent_data.append(data) + else: + raise NotImplementedError(f"Unsupported space type: {type(space)}") -def _get_agent_data_continuous_space(space: ContinuousSpace, agent_portrayal): - """Format agent portrayal data for continuous space. - - Args: - space: the ContinuousSpace instance - agent_portrayal: the agent portrayal callable - - Returns: - list of dicts - """ - all_agent_data = [] - for agent in space._agent_to_index: - agent_data = agent_portrayal(agent) - agent_data["x"] = agent.pos[0] - agent_data["y"] = agent.pos[1] - all_agent_data.append(agent_data) return all_agent_data def _draw_grid(space, agent_portrayal): - match space: - case Grid(): - all_agent_data = _get_agent_data_new_discrete_space(space, agent_portrayal) - case _Grid(): - all_agent_data = _get_agent_data_old__discrete_space(space, agent_portrayal) - case ContinuousSpace(): - all_agent_data = _get_agent_data_continuous_space(space, agent_portrayal) - case _: - raise NotImplementedError( - f"visualizing {type(space)} is currently not supported through altair" - ) + """Create Altair visualization for any supported space type.""" + all_agent_data = get_agent_data(space, agent_portrayal) + + # Handle empty state + if not all_agent_data: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) invalid_tooltips = ["color", "size", "x", "y"] - x_y_type = "ordinal" if not isinstance(space, ContinuousSpace) else "nominal" + x_y_type = ( + "quantitative" + if isinstance( + space, + ContinuousSpace + | HexGrid + | NetworkGrid + | mesa.experimental.continuous_space.ContinuousSpace, + ) + else "ordinal" + ) encoding_dict = { # no x-axis label @@ -283,7 +298,7 @@ def _draw_grid(space, agent_portrayal): # no y-axis label "y": alt.Y("y", axis=None, type=x_y_type), "tooltip": [ - alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value])) + alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() if key not in invalid_tooltips ], @@ -305,7 +320,7 @@ def _draw_grid(space, agent_portrayal): ) # This is the default value for the marker size, which auto-scales # according to the grid area. - if not has_size: + if not has_size and isinstance(space, _Grid | Grid): length = min(space.width, space.height) chart = chart.mark_point(size=30000 / length**2, filled=True) From c49f864a1e407a7df84cb93e8de0042600f689ed Mon Sep 17 00:00:00 2001 From: nissu99 Date: Fri, 31 Jan 2025 19:29:26 +0530 Subject: [PATCH 14/29] resolved conflicts --- .../components/altair_components.py | 16 +++++-- mesa/visualization/solara_viz.py | 6 ++- pyproject.toml | 2 +- tests/test_solara_viz.py | 45 ++++++++++++++++--- 4 files changed, 57 insertions(+), 12 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index e4636dba7c4..0996037ed11 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -5,14 +5,16 @@ import warnings from collections.abc import Callable +import altair as alt import solara import mesa +from mesa.experimental.cell_space.grid import HexGrid with contextlib.suppress(ImportError): import altair as alt -from mesa.experimental.cell_space import Grid, HexGrid +from mesa.experimental.cell_space import Grid from mesa.space import ContinuousSpace, NetworkGrid, _Grid from mesa.visualization.utils import update_counter @@ -50,7 +52,7 @@ def make_altair_space( Args: agent_portrayal: Function to portray agents. propertylayer_portrayal: not yet implemented - post_process :not yet implemented + post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks) space_drawing_kwargs : not yet implemented ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", @@ -65,7 +67,7 @@ def agent_portrayal(a): return {"id": a.unique_id} def MakeSpaceAltair(model): - return SpaceAltair(model, agent_portrayal) + return SpaceAltair(model, agent_portrayal, post_process=post_process) return MakeSpaceAltair @@ -178,7 +180,9 @@ def PlotAltair( @solara.component -def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): +def SpaceAltair( + model, agent_portrayal, dependencies: list[any] | None = None, post_process=None +): """Create an Altair-based space visualization component. Returns: @@ -191,6 +195,10 @@ def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): space = model.space chart = _draw_grid(space, agent_portrayal) + # Apply post-processing if provided + if post_process is not None: + chart = post_process(chart) + solara.FigureAltair(chart) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index f5fde84b1a3..23d603aa820 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -97,7 +97,11 @@ def SolaraViz( reduce update frequency,resulting in faster execution. """ if components == "default": - components = [components_altair.make_altair_space()] + components = [ + components_altair.make_altair_space( + agent_portrayal=None, propertylayer_portrayal=None, post_process=None + ) + ] if model_params is None: model_params = {} diff --git a/pyproject.toml b/pyproject.toml index 5e34996eb93..b9739c917da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ network = [ viz = [ "matplotlib", "solara", - "altair" + "altair", ] # Dev and CI stuff dev = [ diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index 3b8d82fb7bc..6e25502e0b7 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -9,6 +9,8 @@ import mesa import mesa.visualization.components.altair_components import mesa.visualization.components.matplotlib_components +from mesa.space import MultiGrid +from mesa.visualization.components.altair_components import make_altair_space from mesa.visualization.components.matplotlib_components import make_mpl_space_component from mesa.visualization.solara_viz import ( Slider, @@ -101,17 +103,22 @@ def test_call_space_drawer(mocker): # noqa: D103 mesa.visualization.components.altair_components, "SpaceAltair" ) + class MockAgent(mesa.Agent): + def __init__(self, model): + super().__init__(model) + class MockModel(mesa.Model): def __init__(self, seed=None): super().__init__(seed=seed) + self.grid = MultiGrid(width=10, height=10, torus=True) + a = MockAgent(self) + self.grid.place_agent(a, (5, 5)) model = MockModel() - mocker.patch.object(mesa.Model, "__init__", return_value=None) - agent_portrayal = { - "marker": "circle", - "color": "gray", - } + def agent_portrayal(agent): + return {"marker": "o", "color": "gray"} + propertylayer_portrayal = None # initialize with space drawer unspecified (use default) # component must be rendered for code to run @@ -131,7 +138,33 @@ def __init__(self, seed=None): solara.render(SolaraViz(model)) # should call default method with class instance and agent portrayal assert mock_space_matplotlib.call_count == 0 - assert mock_space_altair.call_count == 0 + assert mock_space_altair.call_count == 1 # altair is the default method + + # checking if SpaceAltair is working as intended with post_process + + mock_post_process = mocker.MagicMock() + solara.render( + SolaraViz( + model, + components=[ + make_altair_space( + agent_portrayal, + propertylayer_portrayal, + mock_post_process, + ) + ], + ) + ) + + args, kwargs = mock_space_altair.call_args + assert args == (model, agent_portrayal) + assert kwargs == {"post_process": mock_post_process} + mock_post_process.assert_called_once() + assert mock_space_matplotlib.call_count == 0 + + mock_space_altair.reset_mock() + mock_space_matplotlib.reset_mock() + mock_post_process.reset_mock() # specify a custom space method class AltSpace: From 0621810f6d5b143d2e0ffd511bc65fbae26cfeed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Jan 2025 14:49:05 +0000 Subject: [PATCH 15/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/components/altair_components.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index d6aedf86d30..0996037ed11 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -2,7 +2,6 @@ import contextlib import math - import warnings from collections.abc import Callable @@ -17,7 +16,6 @@ from mesa.experimental.cell_space import Grid from mesa.space import ContinuousSpace, NetworkGrid, _Grid - from mesa.visualization.utils import update_counter @@ -104,7 +102,6 @@ def MakePlotAltair(model): @solara.component - def PlotAltair( model, measure: str | dict[str, str] | list[str] | tuple[str], From b024b329befd91154b46708d0b4402a7e7b79c67 Mon Sep 17 00:00:00 2001 From: nissu99 Date: Tue, 11 Feb 2025 17:21:43 +0530 Subject: [PATCH 16/29] resolves --- .../components/altair_components.py | 111 ++++++++++-------- 1 file changed, 59 insertions(+), 52 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 0996037ed11..ae12549e84e 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -5,11 +5,11 @@ import warnings from collections.abc import Callable -import altair as alt import solara import mesa -from mesa.experimental.cell_space.grid import HexGrid +import mesa.experimental +from mesa.experimental.cell_space import HexGrid with contextlib.suppress(ImportError): import altair as alt @@ -19,7 +19,11 @@ from mesa.visualization.utils import update_counter -def make_space_altair(*args, **kwargs): # noqa: D103 +def make_space_altair(*args, **kwargs): + """Create an Altair chart component for visualizing model space (deprecated). + + This function is deprecated. Use make_altair_space_component instead. + """ warnings.warn( "make_space_altair has been renamed to make_altair_space_component", DeprecationWarning, @@ -210,70 +214,73 @@ def axial_to_pixel(q, r, size=1): def get_agent_data(space, agent_portrayal): - """Generic method to extract agent data for visualization across all space types. + """Draw a Matplotlib-based visualization of the space. Args: - space: Mesa space object - agent_portrayal: Function defining agent visualization properties + space: the space of the mesa model + agent_portrayal: A callable that returns a dict specifying how to show the agent Returns: List of agent data dictionaries with coordinates """ all_agent_data = [] - # New DiscreteSpace - if isinstance(space, Grid): - for cell in space.all_cells: - for agent in cell.agents: - data = agent_portrayal(agent) - data.update({"x": cell.coordinate[0], "y": cell.coordinate[1]}) - all_agent_data.append(data) - - # Legacy Grid - elif isinstance(space, _Grid): - for content, (x, y) in space.coord_iter(): - if not content: - continue - agents = [content] if not hasattr(content, "__iter__") else content - for agent in agents: - data = agent_portrayal(agent) - data.update({"x": x, "y": y}) - all_agent_data.append(data) - - elif isinstance(space, HexGrid): - for content, (q, r) in space.coord_iter(): - if content: - for agent in content: + match space: + case Grid(): + # New DiscreteSpace or experimental cell space + for cell in space.all_cells: + for agent in cell.agents: data = agent_portrayal(agent) - x, y = axial_to_pixel(q, r) - data.update({"x": x, "y": y}) + data.update({"x": cell.coordinate[0], "y": cell.coordinate[1]}) all_agent_data.append(data) - elif isinstance(space, NetworkGrid): - for node in space.G.nodes(): - agents = space.G.nodes[node].get("agent", []) - if isinstance(agents, list): - agent_list = agents - else: - agent_list = [agents] if agents else [] - - for agent in agent_list: - if agent: - pos = space.G.nodes[node].get("pos", (0, 0)) + case _Grid(): + # Legacy MESA grid + for content, (x, y) in space.coord_iter(): + if not content: + continue + agents = [content] if not hasattr(content, "__iter__") else content + for agent in agents: data = agent_portrayal(agent) - data.update({"x": pos[0], "y": pos[1]}) + data.update({"x": x, "y": y}) all_agent_data.append(data) - elif isinstance( - space, ContinuousSpace | mesa.experimental.continuous_space.ContinuousSpace - ): - for agent in space.agents: - data = agent_portrayal(agent) - data.update({"x": agent.pos[0], "y": agent.pos[1]}) - all_agent_data.append(data) + case HexGrid(): + # Hex-based grid + for content, (q, r) in space.coord_iter(): + if content: + for agent in content: + data = agent_portrayal(agent) + x, y = axial_to_pixel(q, r) + data.update({"x": x, "y": y}) + all_agent_data.append(data) + + case NetworkGrid(): + # Network grid (graph-based) + for node in space.G.nodes(): + node_agents = space.G.nodes[node].get("agent", []) + if not isinstance(node_agents, list): + node_agents = [node_agents] if node_agents else [] + for agent in node_agents: + if agent: + pos = space.G.nodes[node].get("pos", (0, 0)) + data = agent_portrayal(agent) + data.update({"x": pos[0], "y": pos[1]}) + all_agent_data.append(data) + + case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): + # Continuous space (including experimental version) + for agent in space.agents: + data = agent_portrayal(agent) + if hasattr(agent, "pos") and agent.pos is not None: + data.update({"x": agent.pos[0], "y": agent.pos[1]}) + else: + data.update({"x": 0, "y": 0}) + all_agent_data.append(data) - else: - raise NotImplementedError(f"Unsupported space type: {type(space)}") + case _: + # Fallback for unrecognized space types + raise NotImplementedError(f"Unsupported space type: {type(space)}") return all_agent_data From 0a37307afc27175c6cca9edd731028d9e3b7e3d7 Mon Sep 17 00:00:00 2001 From: nissu99 Date: Sun, 16 Feb 2025 16:05:50 +0530 Subject: [PATCH 17/29] changes regarding updating agent position and docs --- mesa/visualization/components/__init__.py | 3 - .../components/altair_components.py | 298 +++++++++++++++--- 2 files changed, 260 insertions(+), 41 deletions(-) diff --git a/mesa/visualization/components/__init__.py b/mesa/visualization/components/__init__.py index 58ce45b0bc1..e145fb1cb63 100644 --- a/mesa/visualization/components/__init__.py +++ b/mesa/visualization/components/__init__.py @@ -71,9 +71,6 @@ def make_plot_component( backend: the backend to use {"matplotlib", "altair"} plot_drawing_kwargs: additional keyword arguments to pass onto the backend specific function for making a plotting component - Notes: - altair plotting backend is not yet implemented and planned for mesa 3.1. - Returns: function: A function that creates a plot component """ diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index ae12549e84e..e1758a737fb 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -63,7 +63,7 @@ def make_altair_space( "size", "marker", and "zorder". Other field are ignored and will result in a user warning. Returns: - function: A function that creates a SpaceMatplotlib component + function: A function that creates a SpaceAltair component """ if agent_portrayal is None: @@ -113,7 +113,21 @@ def PlotAltair( width: int = 500, height: int = 300, ) -> solara.FigureAltair: - """Create an Altair plot for model.""" + """Create an Altair plot for model data. + + Args: + model: The mesa.Model instance. + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + post_process: An optional callable that takes an Altair Chart object as + input and returns a modified Chart object. This allows + for customization of the plot (e.g., adding annotations, + changing axis labels). + width: The width of the chart in pixels. + height: The height of the chart in pixels. + + Returns: + A solara.FigureAltair component displaying the generated Altair chart.""" + update_counter.get() df = model.datacollector.get_model_vars_dataframe().reset_index() @@ -189,8 +203,30 @@ def SpaceAltair( ): """Create an Altair-based space visualization component. + Args: + model: The mesa.Model instance containing the space to visualize. + The model must have a `grid` or `space` attribute that + represents the space (e.g., Grid, ContinuousSpace, NetworkGrid). + agent_portrayal: A callable that takes an agent as input and returns + a dictionary specifying how the agent should be + visualized. The dictionary can contain the following keys: + - "color": A string representing the agent's color (e.g., "red", "#FF0000"). + - "size": A number representing the agent's size. + - "tooltip": A string to display as a tooltip when hovering over the agent. + - Any other Vega-Lite mark properties that are supported by Altair. + dependencies: A list of dependencies that trigger a re-render of the + component when they change. This can be used to update + the visualization when the model state changes. + post_process: An optional callable that takes an Altair Chart object + as input and returns a modified Chart object. This allows + for customization of the plot (e.g., adding annotations, + changing axis labels). + Returns: - a solara FigureAltair instance + A solara.FigureAltair instance, which is a Solara component that + renders the Altair chart. + + """ update_counter.get() space = getattr(model, "grid", None) @@ -214,11 +250,11 @@ def axial_to_pixel(q, r, size=1): def get_agent_data(space, agent_portrayal): - """Draw a Matplotlib-based visualization of the space. + """Generic method to extract agent data for visualization across all space types. Args: - space: the space of the mesa model - agent_portrayal: A callable that returns a dict specifying how to show the agent + space: Mesa space object + agent_portrayal: Function defining agent visualization properties Returns: List of agent data dictionaries with coordinates @@ -235,7 +271,7 @@ def get_agent_data(space, agent_portrayal): all_agent_data.append(data) case _Grid(): - # Legacy MESA grid + # Legacy Grid for content, (x, y) in space.coord_iter(): if not content: continue @@ -251,66 +287,119 @@ def get_agent_data(space, agent_portrayal): if content: for agent in content: data = agent_portrayal(agent) - x, y = axial_to_pixel(q, r) - data.update({"x": x, "y": y}) + data.update({"q": q, "r": r}) # Store axial coordinates all_agent_data.append(data) case NetworkGrid(): - # Network grid (graph-based) + # Network grid for node in space.G.nodes(): - node_agents = space.G.nodes[node].get("agent", []) - if not isinstance(node_agents, list): - node_agents = [node_agents] if node_agents else [] - for agent in node_agents: + agents = space.G.nodes[node].get("agent", []) + if not isinstance(agents, list): + agents = [agents] if agents else [] + + for agent in agents: if agent: - pos = space.G.nodes[node].get("pos", (0, 0)) data = agent_portrayal(agent) - data.update({"x": pos[0], "y": pos[1]}) + data.update({"node": node}) # Store node information all_agent_data.append(data) case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): - # Continuous space (including experimental version) + # Continuous space for agent in space.agents: data = agent_portrayal(agent) - if hasattr(agent, "pos") and agent.pos is not None: - data.update({"x": agent.pos[0], "y": agent.pos[1]}) - else: - data.update({"x": 0, "y": 0}) + data.update({"x": agent.pos[0], "y": agent.pos[1]}) all_agent_data.append(data) case _: - # Fallback for unrecognized space types raise NotImplementedError(f"Unsupported space type: {type(space)}") return all_agent_data def _draw_grid(space, agent_portrayal): - """Create Altair visualization for any supported space type.""" + + """Create Altair visualization for any supported space type. + + This function acts as a dispatcher, calling the appropriate + `_draw_*_grid` function based on the type of space provided. + + Args: + space: The mesa.space object to visualize (e.g., Grid, ContinuousSpace, + NetworkGrid). + agent_portrayal: A callable that takes an agent as input and returns + a dictionary specifying how the agent should be + visualized. + + Returns: + An Altair Chart object representing the visualization of the space + and its agents. Returns a text chart "No agents" if there are no agents. + + """ all_agent_data = get_agent_data(space, agent_portrayal) # Handle empty state if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + match space: + case Grid(): + return _draw_discrete_grid(space, all_agent_data, agent_portrayal) + case _Grid(): + return _draw_legacy_grid(space, all_agent_data, agent_portrayal) + case HexGrid(): + return _draw_hex_grid(space, all_agent_data, agent_portrayal) + case NetworkGrid(): + return _draw_network_grid(space, all_agent_data, agent_portrayal) + case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): + return _draw_continuous_space(space, all_agent_data, agent_portrayal) + case _: + raise NotImplementedError(f"Unsupported space type: {type(space)}") + + +def _draw_discrete_grid(space, all_agent_data, agent_portrayal): + """Create Altair visualization for Discrete Grid.""" invalid_tooltips = ["color", "size", "x", "y"] + x_y_type = "ordinal" - x_y_type = ( - "quantitative" - if isinstance( - space, - ContinuousSpace - | HexGrid - | NetworkGrid - | mesa.experimental.continuous_space.ContinuousSpace, + encoding_dict = { + "x": alt.X("x", axis=None, type=x_y_type), + "y": alt.Y("y", axis=None, type=x_y_type), + "tooltip": [ + alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) + for key, value in all_agent_data[0].items() + if key not in invalid_tooltips + ], + } + + has_color = "color" in all_agent_data[0] + if has_color: + encoding_dict["color"] = alt.Color("color", type="nominal") + has_size = "size" in all_agent_data[0] + if has_size: + encoding_dict["size"] = alt.Size("size", type="quantitative") + + chart = ( + alt.Chart( + alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) ) - else "ordinal" + .mark_point(filled=True) + .properties(width=280, height=280) ) + if not has_size: + length = min(space.width, space.height) + chart = chart.mark_point(size=30000 / length**2, filled=True) + + return chart + + +def _draw_legacy_grid(space, all_agent_data, agent_portrayal): + """Create Altair visualization for Legacy Grid.""" + invalid_tooltips = ["color", "size", "x", "y"] + x_y_type = "ordinal" + encoding_dict = { - # no x-axis label "x": alt.X("x", axis=None, type=x_y_type), - # no y-axis label "y": alt.Y("y", axis=None, type=x_y_type), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) @@ -318,6 +407,7 @@ def _draw_grid(space, agent_portrayal): if key not in invalid_tooltips ], } + has_color = "color" in all_agent_data[0] if has_color: encoding_dict["color"] = alt.Color("color", type="nominal") @@ -331,12 +421,144 @@ def _draw_grid(space, agent_portrayal): ) .mark_point(filled=True) .properties(width=280, height=280) - # .configure_view(strokeOpacity=0) # hide grid/chart lines ) - # This is the default value for the marker size, which auto-scales - # according to the grid area. - if not has_size and isinstance(space, _Grid | Grid): + + if not has_size: length = min(space.width, space.height) chart = chart.mark_point(size=30000 / length**2, filled=True) return chart + + +def _draw_hex_grid(space, all_agent_data, agent_portrayal): + """Create Altair visualization for Hex Grid.""" + invalid_tooltips = ["color", "size", "x", "y", "q", "r"] + x_y_type = "quantitative" + + # Parameters for hexagon grid + size = 1.0 + x_spacing = math.sqrt(3) * size + y_spacing = 1.5 * size + + # Calculate x, y coordinates from axial coordinates + for agent_data in all_agent_data: + q = agent_data.pop("q") + r = agent_data.pop("r") + x, y = axial_to_pixel(q, r) + agent_data["x"] = x + agent_data["y"] = y + + encoding_dict = { + "x": alt.X("x", axis=None, type=x_y_type), + "y": alt.Y("y", axis=None, type=x_y_type), + "tooltip": [ + alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) + for key, value in all_agent_data[0].items() + if key not in invalid_tooltips + ], + } + + has_color = "color" in all_agent_data[0] + if has_color: + encoding_dict["color"] = alt.Color("color", type="nominal") + has_size = "size" in all_agent_data[0] + if has_size: + encoding_dict["size"] = alt.Size("size", type="quantitative") + + chart = ( + alt.Chart( + alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) + ) + .mark_point(filled=True) + .properties(width=280, height=280) + ) + + # Calculate proper bounds that account for the full hexagon width and height + x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2) + y_max = space.height * y_spacing + + # Add padding that accounts for the hexagon points + x_padding = ( + size * math.sqrt(3) / 2 + ) + y_padding = size + + chart = chart.properties( + xlim=(-2 * x_padding, x_max + x_padding), + ylim=(-2 * y_padding, y_max + y_padding), + ) + + return chart + + +def _draw_network_grid(space, all_agent_data, agent_portrayal): + """Create Altair visualization for Network Grid.""" + invalid_tooltips = ["color", "size", "x", "y", "node"] + x_y_type = "quantitative" + + # Get x, y coordinates from node positions + for agent_data in all_agent_data: + node = agent_data.pop("node") + pos = space.G.nodes[node].get("pos", (0, 0)) # Default to (0, 0) if no pos + agent_data["x"] = pos[0] + agent_data["y"] = pos[1] + + encoding_dict = { + "x": alt.X("x", axis=None, type=x_y_type), + "y": alt.Y("y", axis=None, type=x_y_type), + "tooltip": [ + alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) + for key, value in all_agent_data[0].items() + if key not in invalid_tooltips + ], + } + + has_color = "color" in all_agent_data[0] + if has_color: + encoding_dict["color"] = alt.Color("color", type="nominal") + has_size = "size" in all_agent_data[0] + if has_size: + encoding_dict["size"] = alt.Size("size", type="quantitative") + + chart = ( + alt.Chart( + alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) + ) + .mark_point(filled=True) + .properties(width=280, height=280) + ) + + return chart + + +def _draw_continuous_space(space, all_agent_data, agent_portrayal): + """Create Altair visualization for Continuous Space.""" + invalid_tooltips = ["color", "size", "x", "y"] + x_y_type = "quantitative" + + encoding_dict = { + "x": alt.X("x", axis=None, type=x_y_type), + "y": alt.Y("y", axis=None, type=x_y_type), + "tooltip": [ + alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) + for key, value in all_agent_data[0].items() + if key not in invalid_tooltips + ], + } + + has_color = "color" in all_agent_data[0] + if has_color: + encoding_dict["color"] = alt.Color("color", type="nominal") + has_size = "size" in all_agent_data[0] + if has_size: + encoding_dict["size"] = alt.Size("size", type="quantitative") + + chart = ( + alt.Chart( + alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) + ) + .mark_point(filled=True) + .properties(width=280, height=280) + ) + + return chart From 4d1bd066ba912ee46b0f553dcc13207181ab5a09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 16 Feb 2025 10:36:06 +0000 Subject: [PATCH 18/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/components/altair_components.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index e1758a737fb..e71305ac3b3 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -126,8 +126,8 @@ def PlotAltair( height: The height of the chart in pixels. Returns: - A solara.FigureAltair component displaying the generated Altair chart.""" - + A solara.FigureAltair component displaying the generated Altair chart. + """ update_counter.get() df = model.datacollector.get_model_vars_dataframe().reset_index() @@ -317,7 +317,6 @@ def get_agent_data(space, agent_portrayal): def _draw_grid(space, agent_portrayal): - """Create Altair visualization for any supported space type. This function acts as a dispatcher, calling the appropriate @@ -334,7 +333,7 @@ def _draw_grid(space, agent_portrayal): An Altair Chart object representing the visualization of the space and its agents. Returns a text chart "No agents" if there are no agents. - """ + """ all_agent_data = get_agent_data(space, agent_portrayal) # Handle empty state @@ -478,10 +477,8 @@ def _draw_hex_grid(space, all_agent_data, agent_portrayal): y_max = space.height * y_spacing # Add padding that accounts for the hexagon points - x_padding = ( - size * math.sqrt(3) / 2 - ) - y_padding = size + x_padding = size * math.sqrt(3) / 2 + y_padding = size chart = chart.properties( xlim=(-2 * x_padding, x_max + x_padding), From d6e96b0131b52811ece22be884e3d3aff111b31e Mon Sep 17 00:00:00 2001 From: nissu99 Date: Mon, 17 Feb 2025 20:32:25 +0530 Subject: [PATCH 19/29] resize fix --- .../components/altair_components.py | 81 ++++++++++++------- 1 file changed, 54 insertions(+), 27 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index e71305ac3b3..bfc5f14805f 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -336,10 +336,12 @@ def _draw_grid(space, agent_portrayal): """ all_agent_data = get_agent_data(space, agent_portrayal) + # Handle empty state if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + match space: case Grid(): return _draw_discrete_grid(space, all_agent_data, agent_portrayal) @@ -389,6 +391,11 @@ def _draw_discrete_grid(space, all_agent_data, agent_portrayal): length = min(space.width, space.height) chart = chart.mark_point(size=30000 / length**2, filled=True) + chart = chart.encode( + x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), + y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) + ) + return chart @@ -426,7 +433,12 @@ def _draw_legacy_grid(space, all_agent_data, agent_portrayal): length = min(space.width, space.height) chart = chart.mark_point(size=30000 / length**2, filled=True) - return chart + chart = chart.encode( + x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), + y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) + ) + + return chart def _draw_hex_grid(space, all_agent_data, agent_portrayal): @@ -447,9 +459,23 @@ def _draw_hex_grid(space, all_agent_data, agent_portrayal): agent_data["x"] = x agent_data["y"] = y + # Calculate proper bounds that account for the full hexagon width and height + x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2) + y_max = space.height * y_spacing + + # Add padding that accounts for the hexagon points + x_padding = size * math.sqrt(3) / 2 + y_padding = size + + x_scale = alt.Scale(domain=(-2 * x_padding, x_max + x_padding)) + y_scale = alt.Scale(domain=(-2 * y_padding, y_max + y_padding)) + + + + encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type), - "y": alt.Y("y", axis=None, type=x_y_type), + "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -472,19 +498,6 @@ def _draw_hex_grid(space, all_agent_data, agent_portrayal): .properties(width=280, height=280) ) - # Calculate proper bounds that account for the full hexagon width and height - x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2) - y_max = space.height * y_spacing - - # Add padding that accounts for the hexagon points - x_padding = size * math.sqrt(3) / 2 - y_padding = size - - chart = chart.properties( - xlim=(-2 * x_padding, x_max + x_padding), - ylim=(-2 * y_padding, y_max + y_padding), - ) - return chart @@ -493,16 +506,24 @@ def _draw_network_grid(space, all_agent_data, agent_portrayal): invalid_tooltips = ["color", "size", "x", "y", "node"] x_y_type = "quantitative" - # Get x, y coordinates from node positions - for agent_data in all_agent_data: - node = agent_data.pop("node") - pos = space.G.nodes[node].get("pos", (0, 0)) # Default to (0, 0) if no pos - agent_data["x"] = pos[0] - agent_data["y"] = pos[1] + # Get x, y coordinates and determine bounds + positions = [space.G.nodes[node].get("pos", (0, 0)) for node in space.G.nodes()] + x_values = [p[0] for p in positions] + y_values = [p[1] for p in positions] + + # Add padding to the bounds + padding = 0.1 # 10% padding + x_min, x_max = min(x_values), max(x_values) + y_min, y_max = min(y_values), max(y_values) + x_range = x_max - x_min + y_range = y_max - y_min + + x_scale = alt.Scale(domain=(x_min - padding * x_range, x_max + padding * x_range)) + y_scale = alt.Scale(domain=(y_min - padding * y_range, y_max + padding * y_range)) encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type), - "y": alt.Y("y", axis=None, type=x_y_type), + "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -523,9 +544,10 @@ def _draw_network_grid(space, all_agent_data, agent_portrayal): ) .mark_point(filled=True) .properties(width=280, height=280) + ) - return chart + return chart def _draw_continuous_space(space, all_agent_data, agent_portrayal): @@ -533,9 +555,13 @@ def _draw_continuous_space(space, all_agent_data, agent_portrayal): invalid_tooltips = ["color", "size", "x", "y"] x_y_type = "quantitative" + x_scale = alt.Scale(domain=(0, space.width)) + y_scale = alt.Scale(domain=(0, space.height)) + + encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type), - "y": alt.Y("y", axis=None, type=x_y_type), + "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -558,4 +584,5 @@ def _draw_continuous_space(space, all_agent_data, agent_portrayal): .properties(width=280, height=280) ) + return chart From 54daa24b89b7b00bc636feb302928f9d1fb8fc80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Feb 2025 15:04:32 +0000 Subject: [PATCH 20/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../components/altair_components.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index bfc5f14805f..a3e2b36713f 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -336,12 +336,10 @@ def _draw_grid(space, agent_portrayal): """ all_agent_data = get_agent_data(space, agent_portrayal) - # Handle empty state if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - match space: case Grid(): return _draw_discrete_grid(space, all_agent_data, agent_portrayal) @@ -392,8 +390,12 @@ def _draw_discrete_grid(space, all_agent_data, agent_portrayal): chart = chart.mark_point(size=30000 / length**2, filled=True) chart = chart.encode( - x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), - y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) + x=alt.X( + "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) + ), + y=alt.Y( + "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) + ), ) return chart @@ -434,11 +436,15 @@ def _draw_legacy_grid(space, all_agent_data, agent_portrayal): chart = chart.mark_point(size=30000 / length**2, filled=True) chart = chart.encode( - x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), - y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) + x=alt.X( + "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) + ), + y=alt.Y( + "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) + ), ) - return chart + return chart def _draw_hex_grid(space, all_agent_data, agent_portrayal): @@ -470,12 +476,9 @@ def _draw_hex_grid(space, all_agent_data, agent_portrayal): x_scale = alt.Scale(domain=(-2 * x_padding, x_max + x_padding)) y_scale = alt.Scale(domain=(-2 * y_padding, y_max + y_padding)) - - - encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), + "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -510,14 +513,14 @@ def _draw_network_grid(space, all_agent_data, agent_portrayal): positions = [space.G.nodes[node].get("pos", (0, 0)) for node in space.G.nodes()] x_values = [p[0] for p in positions] y_values = [p[1] for p in positions] - + # Add padding to the bounds padding = 0.1 # 10% padding x_min, x_max = min(x_values), max(x_values) y_min, y_max = min(y_values), max(y_values) x_range = x_max - x_min y_range = y_max - y_min - + x_scale = alt.Scale(domain=(x_min - padding * x_range, x_max + padding * x_range)) y_scale = alt.Scale(domain=(y_min - padding * y_range, y_max + padding * y_range)) @@ -544,10 +547,9 @@ def _draw_network_grid(space, all_agent_data, agent_portrayal): ) .mark_point(filled=True) .properties(width=280, height=280) - ) - return chart + return chart def _draw_continuous_space(space, all_agent_data, agent_portrayal): @@ -558,10 +560,9 @@ def _draw_continuous_space(space, all_agent_data, agent_portrayal): x_scale = alt.Scale(domain=(0, space.width)) y_scale = alt.Scale(domain=(0, space.height)) - encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), + "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -584,5 +585,4 @@ def _draw_continuous_space(space, all_agent_data, agent_portrayal): .properties(width=280, height=280) ) - return chart From 3d53d520c1c5d97ff7302ab2d3413160bcadbed7 Mon Sep 17 00:00:00 2001 From: nissu99 Date: Tue, 18 Feb 2025 10:37:58 +0530 Subject: [PATCH 21/29] removed unused parameter --- mesa/visualization/components/altair_components.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index a3e2b36713f..d4e308acd99 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -342,15 +342,15 @@ def _draw_grid(space, agent_portrayal): match space: case Grid(): - return _draw_discrete_grid(space, all_agent_data, agent_portrayal) + return _draw_discrete_grid(space, all_agent_data) case _Grid(): - return _draw_legacy_grid(space, all_agent_data, agent_portrayal) + return _draw_legacy_grid(space, all_agent_data) case HexGrid(): - return _draw_hex_grid(space, all_agent_data, agent_portrayal) + return _draw_hex_grid(space, all_agent_data) case NetworkGrid(): - return _draw_network_grid(space, all_agent_data, agent_portrayal) + return _draw_network_grid(space, all_agent_data) case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): - return _draw_continuous_space(space, all_agent_data, agent_portrayal) + return _draw_continuous_space(space, all_agent_data) case _: raise NotImplementedError(f"Unsupported space type: {type(space)}") From 329bd24d13f3dc9d9024e93542606a2d4b71ddf8 Mon Sep 17 00:00:00 2001 From: nissu99 Date: Tue, 18 Feb 2025 19:55:57 +0530 Subject: [PATCH 22/29] removed generic_data_collection function --- .../components/altair_components.py | 191 ++++++++---------- 1 file changed, 88 insertions(+), 103 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index d4e308acd99..22858d25be2 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -249,73 +249,6 @@ def axial_to_pixel(q, r, size=1): return x, y -def get_agent_data(space, agent_portrayal): - """Generic method to extract agent data for visualization across all space types. - - Args: - space: Mesa space object - agent_portrayal: Function defining agent visualization properties - - Returns: - List of agent data dictionaries with coordinates - """ - all_agent_data = [] - - match space: - case Grid(): - # New DiscreteSpace or experimental cell space - for cell in space.all_cells: - for agent in cell.agents: - data = agent_portrayal(agent) - data.update({"x": cell.coordinate[0], "y": cell.coordinate[1]}) - all_agent_data.append(data) - - case _Grid(): - # Legacy Grid - for content, (x, y) in space.coord_iter(): - if not content: - continue - agents = [content] if not hasattr(content, "__iter__") else content - for agent in agents: - data = agent_portrayal(agent) - data.update({"x": x, "y": y}) - all_agent_data.append(data) - - case HexGrid(): - # Hex-based grid - for content, (q, r) in space.coord_iter(): - if content: - for agent in content: - data = agent_portrayal(agent) - data.update({"q": q, "r": r}) # Store axial coordinates - all_agent_data.append(data) - - case NetworkGrid(): - # Network grid - for node in space.G.nodes(): - agents = space.G.nodes[node].get("agent", []) - if not isinstance(agents, list): - agents = [agents] if agents else [] - - for agent in agents: - if agent: - data = agent_portrayal(agent) - data.update({"node": node}) # Store node information - all_agent_data.append(data) - - case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): - # Continuous space - for agent in space.agents: - data = agent_portrayal(agent) - data.update({"x": agent.pos[0], "y": agent.pos[1]}) - all_agent_data.append(data) - - case _: - raise NotImplementedError(f"Unsupported space type: {type(space)}") - - return all_agent_data - - def _draw_grid(space, agent_portrayal): """Create Altair visualization for any supported space type. @@ -334,29 +267,36 @@ def _draw_grid(space, agent_portrayal): and its agents. Returns a text chart "No agents" if there are no agents. """ - all_agent_data = get_agent_data(space, agent_portrayal) - - # Handle empty state - if not all_agent_data: + # Handle empty state first + if not space.agents: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - + match space: case Grid(): - return _draw_discrete_grid(space, all_agent_data) + return _draw_discrete_grid(space, agent_portrayal) case _Grid(): - return _draw_legacy_grid(space, all_agent_data) + return _draw_legacy_grid(space, agent_portrayal) case HexGrid(): - return _draw_hex_grid(space, all_agent_data) + return _draw_hex_grid(space, agent_portrayal) case NetworkGrid(): - return _draw_network_grid(space, all_agent_data) + return _draw_network_grid(space, agent_portrayal) case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): - return _draw_continuous_space(space, all_agent_data) + return _draw_continuous_space(space, agent_portrayal) case _: raise NotImplementedError(f"Unsupported space type: {type(space)}") - -def _draw_discrete_grid(space, all_agent_data, agent_portrayal): +def _draw_discrete_grid(space, agent_portrayal): """Create Altair visualization for Discrete Grid.""" + all_agent_data = [] + for cell in space.all_cells: + for agent in cell.agents: + data = agent_portrayal(agent) + data.update({"x": cell.coordinate[0], "y": cell.coordinate[1]}) + all_agent_data.append(data) + + if not all_agent_data: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + invalid_tooltips = ["color", "size", "x", "y"] x_y_type = "ordinal" @@ -390,19 +330,28 @@ def _draw_discrete_grid(space, all_agent_data, agent_portrayal): chart = chart.mark_point(size=30000 / length**2, filled=True) chart = chart.encode( - x=alt.X( - "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) - ), - y=alt.Y( - "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) - ), + x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), + y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) ) return chart -def _draw_legacy_grid(space, all_agent_data, agent_portrayal): +def _draw_legacy_grid(space, agent_portrayal): """Create Altair visualization for Legacy Grid.""" + all_agent_data = [] + for content, (x, y) in space.coord_iter(): + if not content: + continue + agents = [content] if not hasattr(content, "__iter__") else content + for agent in agents: + data = agent_portrayal(agent) + data.update({"x": x, "y": y}) + all_agent_data.append(data) + + if not all_agent_data: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + invalid_tooltips = ["color", "size", "x", "y"] x_y_type = "ordinal" @@ -436,19 +385,26 @@ def _draw_legacy_grid(space, all_agent_data, agent_portrayal): chart = chart.mark_point(size=30000 / length**2, filled=True) chart = chart.encode( - x=alt.X( - "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) - ), - y=alt.Y( - "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) - ), + x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), + y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) ) - return chart + return chart -def _draw_hex_grid(space, all_agent_data, agent_portrayal): +def _draw_hex_grid(space, agent_portrayal): """Create Altair visualization for Hex Grid.""" + all_agent_data = [] + for content, (q, r) in space.coord_iter(): + if content: + for agent in content: + data = agent_portrayal(agent) + data.update({"q": q, "r": r}) + all_agent_data.append(data) + + if not all_agent_data: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + invalid_tooltips = ["color", "size", "x", "y", "q", "r"] x_y_type = "quantitative" @@ -476,9 +432,10 @@ def _draw_hex_grid(space, all_agent_data, agent_portrayal): x_scale = alt.Scale(domain=(-2 * x_padding, x_max + x_padding)) y_scale = alt.Scale(domain=(-2 * y_padding, y_max + y_padding)) + encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), + "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -504,8 +461,24 @@ def _draw_hex_grid(space, all_agent_data, agent_portrayal): return chart -def _draw_network_grid(space, all_agent_data, agent_portrayal): +def _draw_network_grid(space, agent_portrayal): """Create Altair visualization for Network Grid.""" + all_agent_data = [] + for node in space.G.nodes(): + agents = space.G.nodes[node].get("agent", []) + if not isinstance(agents, list): + agents = [agents] if agents else [] + + for agent in agents: + if agent: + data = agent_portrayal(agent) + pos = space.G.nodes[node].get("pos", (0, 0)) + data.update({"x": pos[0], "y": pos[1]}) + all_agent_data.append(data) + + if not all_agent_data: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + invalid_tooltips = ["color", "size", "x", "y", "node"] x_y_type = "quantitative" @@ -513,14 +486,14 @@ def _draw_network_grid(space, all_agent_data, agent_portrayal): positions = [space.G.nodes[node].get("pos", (0, 0)) for node in space.G.nodes()] x_values = [p[0] for p in positions] y_values = [p[1] for p in positions] - + # Add padding to the bounds padding = 0.1 # 10% padding x_min, x_max = min(x_values), max(x_values) y_min, y_max = min(y_values), max(y_values) x_range = x_max - x_min y_range = y_max - y_min - + x_scale = alt.Scale(domain=(x_min - padding * x_range, x_max + padding * x_range)) y_scale = alt.Scale(domain=(y_min - padding * y_range, y_max + padding * y_range)) @@ -547,22 +520,33 @@ def _draw_network_grid(space, all_agent_data, agent_portrayal): ) .mark_point(filled=True) .properties(width=280, height=280) + ) - return chart + return chart -def _draw_continuous_space(space, all_agent_data, agent_portrayal): +def _draw_continuous_space(space, agent_portrayal): """Create Altair visualization for Continuous Space.""" + all_agent_data = [] + for agent in space.agents: + data = agent_portrayal(agent) + data.update({"x": agent.pos[0], "y": agent.pos[1]}) + all_agent_data.append(data) + + if not all_agent_data: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + invalid_tooltips = ["color", "size", "x", "y"] x_y_type = "quantitative" x_scale = alt.Scale(domain=(0, space.width)) y_scale = alt.Scale(domain=(0, space.height)) + encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), + "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -585,4 +569,5 @@ def _draw_continuous_space(space, all_agent_data, agent_portrayal): .properties(width=280, height=280) ) + return chart From a29c557a41a1eb2a4d110029be9ad708c7c37916 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Feb 2025 14:26:14 +0000 Subject: [PATCH 23/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../components/altair_components.py | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 22858d25be2..b3d8e93af28 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -270,7 +270,7 @@ def _draw_grid(space, agent_portrayal): # Handle empty state first if not space.agents: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - + match space: case Grid(): return _draw_discrete_grid(space, agent_portrayal) @@ -285,6 +285,7 @@ def _draw_grid(space, agent_portrayal): case _: raise NotImplementedError(f"Unsupported space type: {type(space)}") + def _draw_discrete_grid(space, agent_portrayal): """Create Altair visualization for Discrete Grid.""" all_agent_data = [] @@ -330,8 +331,12 @@ def _draw_discrete_grid(space, agent_portrayal): chart = chart.mark_point(size=30000 / length**2, filled=True) chart = chart.encode( - x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), - y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) + x=alt.X( + "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) + ), + y=alt.Y( + "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) + ), ) return chart @@ -385,11 +390,15 @@ def _draw_legacy_grid(space, agent_portrayal): chart = chart.mark_point(size=30000 / length**2, filled=True) chart = chart.encode( - x=alt.X("x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1))), - y=alt.Y("y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1))) + x=alt.X( + "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) + ), + y=alt.Y( + "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) + ), ) - return chart + return chart def _draw_hex_grid(space, agent_portrayal): @@ -432,10 +441,9 @@ def _draw_hex_grid(space, agent_portrayal): x_scale = alt.Scale(domain=(-2 * x_padding, x_max + x_padding)) y_scale = alt.Scale(domain=(-2 * y_padding, y_max + y_padding)) - encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), + "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -468,7 +476,7 @@ def _draw_network_grid(space, agent_portrayal): agents = space.G.nodes[node].get("agent", []) if not isinstance(agents, list): agents = [agents] if agents else [] - + for agent in agents: if agent: data = agent_portrayal(agent) @@ -486,14 +494,14 @@ def _draw_network_grid(space, agent_portrayal): positions = [space.G.nodes[node].get("pos", (0, 0)) for node in space.G.nodes()] x_values = [p[0] for p in positions] y_values = [p[1] for p in positions] - + # Add padding to the bounds padding = 0.1 # 10% padding x_min, x_max = min(x_values), max(x_values) y_min, y_max = min(y_values), max(y_values) x_range = x_max - x_min y_range = y_max - y_min - + x_scale = alt.Scale(domain=(x_min - padding * x_range, x_max + padding * x_range)) y_scale = alt.Scale(domain=(y_min - padding * y_range, y_max + padding * y_range)) @@ -520,10 +528,9 @@ def _draw_network_grid(space, agent_portrayal): ) .mark_point(filled=True) .properties(width=280, height=280) - ) - return chart + return chart def _draw_continuous_space(space, agent_portrayal): @@ -543,10 +550,9 @@ def _draw_continuous_space(space, agent_portrayal): x_scale = alt.Scale(domain=(0, space.width)) y_scale = alt.Scale(domain=(0, space.height)) - encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type,scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type,scale=y_scale), + "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), + "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -569,5 +575,4 @@ def _draw_continuous_space(space, agent_portrayal): .properties(width=280, height=280) ) - return chart From 0631373985a44d806bb26196c0a068d1235268ce Mon Sep 17 00:00:00 2001 From: nissu99 Date: Sat, 1 Mar 2025 20:37:47 +0530 Subject: [PATCH 24/29] fixed altiar plotting --- .../components/altair_components.py | 320 +++++++++++------- 1 file changed, 192 insertions(+), 128 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index b3d8e93af28..d0219982061 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -2,8 +2,10 @@ import contextlib import math +import itertools import warnings from collections.abc import Callable +from functools import lru_cache import solara @@ -17,7 +19,7 @@ from mesa.experimental.cell_space import Grid from mesa.space import ContinuousSpace, NetworkGrid, _Grid from mesa.visualization.utils import update_counter - +import numpy as np def make_space_altair(*args, **kwargs): """Create an Altair chart component for visualizing model space (deprecated). @@ -66,7 +68,6 @@ def make_altair_space( function: A function that creates a SpaceAltair component """ if agent_portrayal is None: - def agent_portrayal(a): return {"id": a.unique_id} @@ -226,20 +227,19 @@ def SpaceAltair( A solara.FigureAltair instance, which is a Solara component that renders the Altair chart. - """ + # Force update on dependencies change update_counter.get() space = getattr(model, "grid", None) if space is None: - # Sometimes the space is defined as model.space instead of model.grid space = model.space chart = _draw_grid(space, agent_portrayal) - # Apply post-processing if provided if post_process is not None: chart = post_process(chart) - solara.FigureAltair(chart) + # Return the rendered chart + return solara.FigureAltair(chart) def axial_to_pixel(q, r, size=1): @@ -272,12 +272,12 @@ def _draw_grid(space, agent_portrayal): return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) match space: + case HexGrid(): + return _draw_hex_grid(space, agent_portrayal) case Grid(): return _draw_discrete_grid(space, agent_portrayal) case _Grid(): return _draw_legacy_grid(space, agent_portrayal) - case HexGrid(): - return _draw_hex_grid(space, agent_portrayal) case NetworkGrid(): return _draw_network_grid(space, agent_portrayal) case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): @@ -289,55 +289,53 @@ def _draw_grid(space, agent_portrayal): def _draw_discrete_grid(space, agent_portrayal): """Create Altair visualization for Discrete Grid.""" all_agent_data = [] + + # Collect agent data for cell in space.all_cells: for agent in cell.agents: data = agent_portrayal(agent) - data.update({"x": cell.coordinate[0], "y": cell.coordinate[1]}) + data.update({ + "x": float(cell.coordinate[0]), + "y": float(cell.coordinate[1]) + }) all_agent_data.append(data) if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - invalid_tooltips = ["color", "size", "x", "y"] - x_y_type = "ordinal" + # Create base chart + base = alt.Chart(alt.Data(values=all_agent_data)).properties( + width=280, height=280 + ) - encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type), - "y": alt.Y("y", axis=None, type=x_y_type), - "tooltip": [ - alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) - for key, value in all_agent_data[0].items() - if key not in invalid_tooltips - ], + # Configure encodings + encodings = { + "x": alt.X( + "x:Q", + scale=alt.Scale(domain=[0, space.width-1]), + axis=alt.Axis(grid=True) # Enable grid + ), + "y": alt.Y( + "y:Q", + scale=alt.Scale(domain=[0, space.height-1]), + axis=alt.Axis(grid=True) # Enable grid + ), } - has_color = "color" in all_agent_data[0] - if has_color: - encoding_dict["color"] = alt.Color("color", type="nominal") - has_size = "size" in all_agent_data[0] - if has_size: - encoding_dict["size"] = alt.Size("size", type="quantitative") + # Add color encoding if present + if "color" in all_agent_data[0]: + encodings["color"] = alt.Color("color:N") - chart = ( - alt.Chart( - alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) - ) - .mark_point(filled=True) - .properties(width=280, height=280) - ) - - if not has_size: - length = min(space.width, space.height) - chart = chart.mark_point(size=30000 / length**2, filled=True) + # Add size encoding if present + if "size" in all_agent_data: + encodings["size"] = alt.Size("size:Q") + else: + # Default size based on grid dimensions + point_size = 30000 / min(space.width, space.height)**2 + base = base.mark_point(size=point_size, filled=True) - chart = chart.encode( - x=alt.X( - "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) - ), - y=alt.Y( - "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) - ), - ) + # Create final chart with encodings + chart = base.encode(**encodings) return chart @@ -361,8 +359,8 @@ def _draw_legacy_grid(space, agent_portrayal): x_y_type = "ordinal" encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type), - "y": alt.Y("y", axis=None, type=x_y_type), + "x": alt.X("x", axis=alt.Axis(grid=True), type=x_y_type), # Enable grid + "y": alt.Y("y", axis=alt.Axis(grid=True), type=x_y_type), # Enable grid "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -373,7 +371,7 @@ def _draw_legacy_grid(space, agent_portrayal): has_color = "color" in all_agent_data[0] if has_color: encoding_dict["color"] = alt.Color("color", type="nominal") - has_size = "size" in all_agent_data[0] + has_size = "size" in all_agent_data if has_size: encoding_dict["size"] = alt.Size("size", type="quantitative") @@ -391,81 +389,147 @@ def _draw_legacy_grid(space, agent_portrayal): chart = chart.encode( x=alt.X( - "x", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) + "x", axis=alt.Axis(grid=True), type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) ), y=alt.Y( - "y", axis=None, type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) + "y", axis=alt.Axis(grid=True), type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) ), ) return chart -def _draw_hex_grid(space, agent_portrayal): - """Create Altair visualization for Hex Grid.""" - all_agent_data = [] - for content, (q, r) in space.coord_iter(): - if content: - for agent in content: - data = agent_portrayal(agent) - data.update({"q": q, "r": r}) - all_agent_data.append(data) +@lru_cache(maxsize=1024, typed=True) +def _get_hexmesh( + width: int, height: int, size: float = 1.0 +) -> list[tuple[float, float]]: + """Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon.""" + + # Helper function for getting the vertices of a hexagon given the center and size + def _get_hex_vertices( + center_x: float, center_y: float, size: float = 1.0 + ) -> list[tuple[float, float]]: + """Get vertices for a hexagon centered at (center_x, center_y).""" + vertices = [ + (center_x, center_y + size), # top + (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right + (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right + (center_x, center_y - size), # bottom + (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left + (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left + ] + return vertices + + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size + hexagons = [] - if not all_agent_data: - return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + for row, col in itertools.product(range(height), range(width)): + # Calculate center position with offset for even rows + x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) + y = row * y_spacing + hexagons.append(_get_hex_vertices(x, y, size)) + + return hexagons - invalid_tooltips = ["color", "size", "x", "y", "q", "r"] - x_y_type = "quantitative" - # Parameters for hexagon grid +def _draw_hex_grid(space, agent_portrayal): + """Create Altair visualization for Hex Grid.""" size = 1.0 x_spacing = math.sqrt(3) * size y_spacing = 1.5 * size - # Calculate x, y coordinates from axial coordinates - for agent_data in all_agent_data: - q = agent_data.pop("q") - r = agent_data.pop("r") - x, y = axial_to_pixel(q, r) - agent_data["x"] = x - agent_data["y"] = y + # Get cached hex mesh + hexagons = _get_hexmesh(space.width, space.height, size) - # Calculate proper bounds that account for the full hexagon width and height + # Calculate bounds x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2) y_max = space.height * y_spacing - - # Add padding that accounts for the hexagon points x_padding = size * math.sqrt(3) / 2 y_padding = size - x_scale = alt.Scale(domain=(-2 * x_padding, x_max + x_padding)) - y_scale = alt.Scale(domain=(-2 * y_padding, y_max + y_padding)) + # Prepare data for grid lines + hex_lines = [] + hex_centers = [] + + for idx, hexagon in enumerate(hexagons): + # Calculate center of this hexagon + x_center = sum(p[0] for p in hexagon) / 6 + y_center = sum(p[1] for p in hexagon) / 6 + + # Calculate row and column from index + row = idx // space.width + col = idx % space.width + + # Store center + hex_centers.append((col, row, x_center, y_center)) + + # Create line segments + for i in range(6): + x1, y1 = hexagon[i] + x2, y2 = hexagon[(i + 1) % 6] + hex_lines.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2}) + + # Create grid lines layer + grid_lines = alt.Chart(alt.Data(values=hex_lines)).mark_rule( + color='gray', + strokeWidth=1, + opacity=0.5 + ).encode( + x='x1:Q', + y='y1:Q', + x2='x2:Q', + y2='y2:Q' + ).properties( + width=280, + height=280 + ) - encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), - "tooltip": [ - alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) - for key, value in all_agent_data[0].items() - if key not in invalid_tooltips - ], - } + # Create mapping from coordinate to center position + center_map = {(col, row): (x, y) for col, row, x, y in hex_centers} - has_color = "color" in all_agent_data[0] - if has_color: - encoding_dict["color"] = alt.Color("color", type="nominal") - has_size = "size" in all_agent_data[0] - if has_size: - encoding_dict["size"] = alt.Size("size", type="quantitative") + # Create agents layer + all_agent_data = [] + + for cell in space.all_cells: + for agent in cell.agents: + data = agent_portrayal(agent) + # Get hex center for this cell's coordinate + coord = cell.coordinate + if coord in center_map: + x, y = center_map[coord] + data.update({"x": x, "y": y}) + all_agent_data.append(data) - chart = ( - alt.Chart( - alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) - ) - .mark_point(filled=True) - .properties(width=280, height=280) + if not all_agent_data: + return grid_lines + + # Create agent points layer + agent_layer = alt.Chart( + alt.Data(values=all_agent_data) + ).mark_circle( + filled=True, + size=150 + ).encode( + x=alt.X('x:Q', scale=alt.Scale(domain=[-2 * x_padding, x_max + x_padding])), + y=alt.Y('y:Q', scale=alt.Scale(domain=[-2 * y_padding, y_max + y_padding])), + ).properties( + width=280, + height=280 ) + # Add color encoding if present + if all_agent_data and "color" in all_agent_data[0]: + agent_layer = agent_layer.encode(color=alt.Color("color:N")) + + if all_agent_data and "size" in all_agent_data[0]: + agent_layer = agent_layer.encode(size=alt.Size("size:Q")) + + chart = (grid_lines + agent_layer).resolve_scale( + x='shared', + y='shared' + ) + return chart @@ -498,7 +562,7 @@ def _draw_network_grid(space, agent_portrayal): # Add padding to the bounds padding = 0.1 # 10% padding x_min, x_max = min(x_values), max(x_values) - y_min, y_max = min(y_values), max(y_values) + y_min, y_max = min(y_values), max(y_values) x_range = x_max - x_min y_range = y_max - y_min @@ -506,8 +570,8 @@ def _draw_network_grid(space, agent_portrayal): y_scale = alt.Scale(domain=(y_min - padding * y_range, y_max + padding * y_range)) encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), + "x": alt.X("x", axis=alt.Axis(grid=True), type=x_y_type, scale=x_scale), + "y": alt.Y("y", axis=alt.Axis(grid=True), type=x_y_type, scale=y_scale), "tooltip": [ alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() @@ -518,7 +582,7 @@ def _draw_network_grid(space, agent_portrayal): has_color = "color" in all_agent_data[0] if has_color: encoding_dict["color"] = alt.Color("color", type="nominal") - has_size = "size" in all_agent_data[0] + has_size = "size" in all_agent_data[0] if has_size: encoding_dict["size"] = alt.Size("size", type="quantitative") @@ -536,43 +600,43 @@ def _draw_network_grid(space, agent_portrayal): def _draw_continuous_space(space, agent_portrayal): """Create Altair visualization for Continuous Space.""" all_agent_data = [] + for agent in space.agents: data = agent_portrayal(agent) - data.update({"x": agent.pos[0], "y": agent.pos[1]}) + data.update({ + "x": float(agent.pos[0]), + "y": float(agent.pos[1]) + }) all_agent_data.append(data) if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - invalid_tooltips = ["color", "size", "x", "y"] - x_y_type = "quantitative" - - x_scale = alt.Scale(domain=(0, space.width)) - y_scale = alt.Scale(domain=(0, space.height)) + base = alt.Chart(alt.Data(values=all_agent_data)).properties( + width=280, height=280 + ) - encoding_dict = { - "x": alt.X("x", axis=None, type=x_y_type, scale=x_scale), - "y": alt.Y("y", axis=None, type=x_y_type, scale=y_scale), - "tooltip": [ - alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) - for key, value in all_agent_data[0].items() - if key not in invalid_tooltips - ], + encodings = { + "x": alt.X( + "x:Q", + scale=alt.Scale(domain=[0, space.width]), + axis=alt.Axis(grid=True) # Enable grid + ), + "y": alt.Y( + "y:Q", + scale=alt.Scale(domain=[0, space.height]), + axis=alt.Axis(grid=True) # Enable grid + ) } - has_color = "color" in all_agent_data[0] - if has_color: - encoding_dict["color"] = alt.Color("color", type="nominal") - has_size = "size" in all_agent_data[0] - if has_size: - encoding_dict["size"] = alt.Size("size", type="quantitative") - - chart = ( - alt.Chart( - alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) - ) - .mark_point(filled=True) - .properties(width=280, height=280) - ) + if "color" in all_agent_data[0]: + encodings["color"] = alt.Color("color:N") + + if "size" in all_agent_data: + encodings["size"] = alt.Size("size:Q") + else: + base = base.mark_point(size=100, filled=True) + chart = base.encode(**encodings) + return chart From 22044f6798bbeb4ef9e52555a5fb1c8f17d5b056 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Mar 2025 15:08:03 +0000 Subject: [PATCH 25/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../components/altair_components.py | 124 ++++++++---------- 1 file changed, 56 insertions(+), 68 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index d0219982061..2537f952a51 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,8 +1,8 @@ """Altair based solara components for visualization mesa spaces.""" import contextlib -import math import itertools +import math import warnings from collections.abc import Callable from functools import lru_cache @@ -16,10 +16,12 @@ with contextlib.suppress(ImportError): import altair as alt +import numpy as np + from mesa.experimental.cell_space import Grid from mesa.space import ContinuousSpace, NetworkGrid, _Grid from mesa.visualization.utils import update_counter -import numpy as np + def make_space_altair(*args, **kwargs): """Create an Altair chart component for visualizing model space (deprecated). @@ -68,6 +70,7 @@ def make_altair_space( function: A function that creates a SpaceAltair component """ if agent_portrayal is None: + def agent_portrayal(a): return {"id": a.unique_id} @@ -289,36 +292,33 @@ def _draw_grid(space, agent_portrayal): def _draw_discrete_grid(space, agent_portrayal): """Create Altair visualization for Discrete Grid.""" all_agent_data = [] - + # Collect agent data for cell in space.all_cells: for agent in cell.agents: data = agent_portrayal(agent) - data.update({ - "x": float(cell.coordinate[0]), - "y": float(cell.coordinate[1]) - }) + data.update( + {"x": float(cell.coordinate[0]), "y": float(cell.coordinate[1])} + ) all_agent_data.append(data) if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) # Create base chart - base = alt.Chart(alt.Data(values=all_agent_data)).properties( - width=280, height=280 - ) + base = alt.Chart(alt.Data(values=all_agent_data)).properties(width=280, height=280) # Configure encodings encodings = { "x": alt.X( "x:Q", - scale=alt.Scale(domain=[0, space.width-1]), - axis=alt.Axis(grid=True) # Enable grid + scale=alt.Scale(domain=[0, space.width - 1]), + axis=alt.Axis(grid=True), # Enable grid ), "y": alt.Y( - "y:Q", - scale=alt.Scale(domain=[0, space.height-1]), - axis=alt.Axis(grid=True) # Enable grid + "y:Q", + scale=alt.Scale(domain=[0, space.height - 1]), + axis=alt.Axis(grid=True), # Enable grid ), } @@ -326,12 +326,12 @@ def _draw_discrete_grid(space, agent_portrayal): if "color" in all_agent_data[0]: encodings["color"] = alt.Color("color:N") - # Add size encoding if present + # Add size encoding if present if "size" in all_agent_data: encodings["size"] = alt.Size("size:Q") else: # Default size based on grid dimensions - point_size = 30000 / min(space.width, space.height)**2 + point_size = 30000 / min(space.width, space.height) ** 2 base = base.mark_point(size=point_size, filled=True) # Create final chart with encodings @@ -389,10 +389,16 @@ def _draw_legacy_grid(space, agent_portrayal): chart = chart.encode( x=alt.X( - "x", axis=alt.Axis(grid=True), type=x_y_type, scale=alt.Scale(domain=(0, space.width - 1)) + "x", + axis=alt.Axis(grid=True), + type=x_y_type, + scale=alt.Scale(domain=(0, space.width - 1)), ), y=alt.Y( - "y", axis=alt.Axis(grid=True), type=x_y_type, scale=alt.Scale(domain=(0, space.height - 1)) + "y", + axis=alt.Axis(grid=True), + type=x_y_type, + scale=alt.Scale(domain=(0, space.height - 1)), ), ) @@ -451,19 +457,19 @@ def _draw_hex_grid(space, agent_portrayal): # Prepare data for grid lines hex_lines = [] hex_centers = [] - + for idx, hexagon in enumerate(hexagons): # Calculate center of this hexagon x_center = sum(p[0] for p in hexagon) / 6 y_center = sum(p[1] for p in hexagon) / 6 - + # Calculate row and column from index row = idx // space.width col = idx % space.width - + # Store center hex_centers.append((col, row, x_center, y_center)) - + # Create line segments for i in range(6): x1, y1 = hexagon[i] @@ -471,18 +477,11 @@ def _draw_hex_grid(space, agent_portrayal): hex_lines.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2}) # Create grid lines layer - grid_lines = alt.Chart(alt.Data(values=hex_lines)).mark_rule( - color='gray', - strokeWidth=1, - opacity=0.5 - ).encode( - x='x1:Q', - y='y1:Q', - x2='x2:Q', - y2='y2:Q' - ).properties( - width=280, - height=280 + grid_lines = ( + alt.Chart(alt.Data(values=hex_lines)) + .mark_rule(color="gray", strokeWidth=1, opacity=0.5) + .encode(x="x1:Q", y="y1:Q", x2="x2:Q", y2="y2:Q") + .properties(width=280, height=280) ) # Create mapping from coordinate to center position @@ -490,7 +489,7 @@ def _draw_hex_grid(space, agent_portrayal): # Create agents layer all_agent_data = [] - + for cell in space.all_cells: for agent in cell.agents: data = agent_portrayal(agent) @@ -505,17 +504,14 @@ def _draw_hex_grid(space, agent_portrayal): return grid_lines # Create agent points layer - agent_layer = alt.Chart( - alt.Data(values=all_agent_data) - ).mark_circle( - filled=True, - size=150 - ).encode( - x=alt.X('x:Q', scale=alt.Scale(domain=[-2 * x_padding, x_max + x_padding])), - y=alt.Y('y:Q', scale=alt.Scale(domain=[-2 * y_padding, y_max + y_padding])), - ).properties( - width=280, - height=280 + agent_layer = ( + alt.Chart(alt.Data(values=all_agent_data)) + .mark_circle(filled=True, size=150) + .encode( + x=alt.X("x:Q", scale=alt.Scale(domain=[-2 * x_padding, x_max + x_padding])), + y=alt.Y("y:Q", scale=alt.Scale(domain=[-2 * y_padding, y_max + y_padding])), + ) + .properties(width=280, height=280) ) # Add color encoding if present @@ -525,11 +521,8 @@ def _draw_hex_grid(space, agent_portrayal): if all_agent_data and "size" in all_agent_data[0]: agent_layer = agent_layer.encode(size=alt.Size("size:Q")) - chart = (grid_lines + agent_layer).resolve_scale( - x='shared', - y='shared' - ) - + chart = (grid_lines + agent_layer).resolve_scale(x="shared", y="shared") + return chart @@ -562,7 +555,7 @@ def _draw_network_grid(space, agent_portrayal): # Add padding to the bounds padding = 0.1 # 10% padding x_min, x_max = min(x_values), max(x_values) - y_min, y_max = min(y_values), max(y_values) + y_min, y_max = min(y_values), max(y_values) x_range = x_max - x_min y_range = y_max - y_min @@ -582,7 +575,7 @@ def _draw_network_grid(space, agent_portrayal): has_color = "color" in all_agent_data[0] if has_color: encoding_dict["color"] = alt.Color("color", type="nominal") - has_size = "size" in all_agent_data[0] + has_size = "size" in all_agent_data[0] if has_size: encoding_dict["size"] = alt.Size("size", type="quantitative") @@ -600,43 +593,38 @@ def _draw_network_grid(space, agent_portrayal): def _draw_continuous_space(space, agent_portrayal): """Create Altair visualization for Continuous Space.""" all_agent_data = [] - + for agent in space.agents: data = agent_portrayal(agent) - data.update({ - "x": float(agent.pos[0]), - "y": float(agent.pos[1]) - }) + data.update({"x": float(agent.pos[0]), "y": float(agent.pos[1])}) all_agent_data.append(data) if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - base = alt.Chart(alt.Data(values=all_agent_data)).properties( - width=280, height=280 - ) + base = alt.Chart(alt.Data(values=all_agent_data)).properties(width=280, height=280) encodings = { "x": alt.X( "x:Q", scale=alt.Scale(domain=[0, space.width]), - axis=alt.Axis(grid=True) # Enable grid + axis=alt.Axis(grid=True), # Enable grid ), "y": alt.Y( "y:Q", - scale=alt.Scale(domain=[0, space.height]), - axis=alt.Axis(grid=True) # Enable grid - ) + scale=alt.Scale(domain=[0, space.height]), + axis=alt.Axis(grid=True), # Enable grid + ), } if "color" in all_agent_data[0]: encodings["color"] = alt.Color("color:N") - + if "size" in all_agent_data: encodings["size"] = alt.Size("size:Q") else: base = base.mark_point(size=100, filled=True) chart = base.encode(**encodings) - + return chart From 0f6238c90973beb3430a99cb21a2498ee0dddc66 Mon Sep 17 00:00:00 2001 From: nissu99 Date: Mon, 10 Mar 2025 02:36:04 +0530 Subject: [PATCH 26/29] extraction logic changes --- .../components/altair_components.py | 467 +++++++++++------- 1 file changed, 292 insertions(+), 175 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index d0219982061..3b3e52553eb 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -7,11 +7,14 @@ from collections.abc import Callable from functools import lru_cache +import mesa.discrete_space.network +import mesa.visualization +from mesa.visualization.mpl_space_drawing import collect_agent_data,_get_hexmesh import solara import mesa import mesa.experimental -from mesa.experimental.cell_space import HexGrid +from mesa.space import HexSingleGrid,HexMultiGrid with contextlib.suppress(ImportError): import altair as alt @@ -20,6 +23,7 @@ from mesa.space import ContinuousSpace, NetworkGrid, _Grid from mesa.visualization.utils import update_counter import numpy as np +import networkx as nx def make_space_altair(*args, **kwargs): """Create an Altair chart component for visualizing model space (deprecated). @@ -232,9 +236,11 @@ def SpaceAltair( update_counter.get() space = getattr(model, "grid", None) if space is None: + # Sometimes the space is defined as model.space instead of model.grid space = model.space chart = _draw_grid(space, agent_portrayal) + #Apply post_processing if provided if post_process is not None: chart = post_process(chart) @@ -272,13 +278,13 @@ def _draw_grid(space, agent_portrayal): return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) match space: - case HexGrid(): + case HexSingleGrid() | HexMultiGrid(): return _draw_hex_grid(space, agent_portrayal) case Grid(): return _draw_discrete_grid(space, agent_portrayal) case _Grid(): return _draw_legacy_grid(space, agent_portrayal) - case NetworkGrid(): + case NetworkGrid()| mesa.discrete_space.network.Network(): return _draw_network_grid(space, agent_portrayal) case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): return _draw_continuous_space(space, agent_portrayal) @@ -286,22 +292,35 @@ def _draw_grid(space, agent_portrayal): raise NotImplementedError(f"Unsupported space type: {type(space)}") + def _draw_discrete_grid(space, agent_portrayal): """Create Altair visualization for Discrete Grid.""" - all_agent_data = [] - # Collect agent data - for cell in space.all_cells: - for agent in cell.agents: - data = agent_portrayal(agent) - data.update({ - "x": float(cell.coordinate[0]), - "y": float(cell.coordinate[1]) - }) - all_agent_data.append(data) - - if not all_agent_data: + # Get agent data using the collect_agent_data helper function + raw_data = collect_agent_data(space, agent_portrayal) + + # Early exit if no agents + if len(raw_data["loc"]) == 0: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + + # Convert raw_data (dict of arrays) to Altair format (list of dicts) + all_agent_data = [] + for i in range(len(raw_data["loc"])): + agent_dict = { + "x": float(raw_data["loc"][i][0]), + "y": float(raw_data["loc"][i][1]), + "color": raw_data["c"][i], + "size": raw_data["s"][i] + } + # Add other properties if they exist + if len(raw_data["alpha"]) > i: + agent_dict["alpha"] = raw_data["alpha"][i] + if len(raw_data["edgecolors"]) > i: + agent_dict["edgecolor"] = raw_data["edgecolors"][i] + if len(raw_data["linewidths"]) > i: + agent_dict["linewidth"] = raw_data["linewidths"][i] + + all_agent_data.append(agent_dict) # Create base chart base = alt.Chart(alt.Data(values=all_agent_data)).properties( @@ -327,15 +346,13 @@ def _draw_discrete_grid(space, agent_portrayal): encodings["color"] = alt.Color("color:N") # Add size encoding if present - if "size" in all_agent_data: + if "size" in all_agent_data[0]: encodings["size"] = alt.Size("size:Q") + chart = base.mark_point(filled=True).encode(**encodings) else: # Default size based on grid dimensions point_size = 30000 / min(space.width, space.height)**2 - base = base.mark_point(size=point_size, filled=True) - - # Create final chart with encodings - chart = base.encode(**encodings) + chart = base.mark_point(size=point_size, filled=True).encode(**encodings) return chart @@ -343,14 +360,30 @@ def _draw_discrete_grid(space, agent_portrayal): def _draw_legacy_grid(space, agent_portrayal): """Create Altair visualization for Legacy Grid.""" all_agent_data = [] - for content, (x, y) in space.coord_iter(): - if not content: - continue - agents = [content] if not hasattr(content, "__iter__") else content - for agent in agents: - data = agent_portrayal(agent) - data.update({"x": x, "y": y}) - all_agent_data.append(data) + raw_data = collect_agent_data(space, agent_portrayal) + + # Early exit if no agents + if len(raw_data["loc"]) == 0: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + + # Convert raw_data (dict of arrays) to Altair format (list of dicts) + all_agent_data = [] + for i in range(len(raw_data["loc"])): + agent_dict = { + "x": float(raw_data["loc"][i][0]), + "y": float(raw_data["loc"][i][1]), + "color": raw_data["c"][i], + "size": raw_data["s"][i] + } + # Add other properties if they exist + if len(raw_data["alpha"]) > i: + agent_dict["alpha"] = raw_data["alpha"][i] + if len(raw_data["edgecolors"]) > i: + agent_dict["edgecolor"] = raw_data["edgecolors"][i] + if len(raw_data["linewidths"]) > i: + agent_dict["linewidth"] = raw_data["linewidths"][i] + + all_agent_data.append(agent_dict) if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) @@ -399,76 +432,57 @@ def _draw_legacy_grid(space, agent_portrayal): return chart -@lru_cache(maxsize=1024, typed=True) -def _get_hexmesh( - width: int, height: int, size: float = 1.0 -) -> list[tuple[float, float]]: - """Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon.""" - - # Helper function for getting the vertices of a hexagon given the center and size - def _get_hex_vertices( - center_x: float, center_y: float, size: float = 1.0 - ) -> list[tuple[float, float]]: - """Get vertices for a hexagon centered at (center_x, center_y).""" - vertices = [ - (center_x, center_y + size), # top - (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right - (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right - (center_x, center_y - size), # bottom - (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left - (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left - ] - return vertices - - x_spacing = np.sqrt(3) * size - y_spacing = 1.5 * size - hexagons = [] - - for row, col in itertools.product(range(height), range(width)): - # Calculate center position with offset for even rows - x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) - y = row * y_spacing - hexagons.append(_get_hex_vertices(x, y, size)) - - return hexagons - - def _draw_hex_grid(space, agent_portrayal): """Create Altair visualization for Hex Grid.""" + # Parameters for hexagon grid size = 1.0 - x_spacing = math.sqrt(3) * size + x_spacing = np.sqrt(3) * size y_spacing = 1.5 * size - # Get cached hex mesh - hexagons = _get_hexmesh(space.width, space.height, size) - + # Get color and size defaults + s_default = (180 / max(space.width, space.height)) ** 2 + + # Get agent data using the collect_agent_data helper function + raw_data = collect_agent_data(space, agent_portrayal) + + # Early exit if no agents + if len(raw_data["loc"]) == 0: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + + # Transform hex coordinates to pixel coordinates + loc = raw_data["loc"].astype(float) + if loc.size > 0: + # Apply the hex grid transformation for agent positions + loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1] % 2) * (x_spacing / 2)) + loc[:, 1] = loc[:, 1] * y_spacing + + # Convert raw_data to Altair format + all_agent_data = [] + for i in range(len(raw_data["loc"])): + agent_dict = { + "x": float(loc[i][0]), # Use transformed coordinates + "y": float(loc[i][1]), # Use transformed coordinates + "color": raw_data["c"][i], + "size": raw_data["s"][i] + } + # Add other properties if they exist + if len(raw_data["alpha"]) > i: + agent_dict["alpha"] = raw_data["alpha"][i] + if len(raw_data["edgecolors"]) > i: + agent_dict["edgecolor"] = raw_data["edgecolors"][i] + if len(raw_data["linewidths"]) > i: + agent_dict["linewidth"] = raw_data["linewidths"][i] + + all_agent_data.append(agent_dict) + # Calculate bounds x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2) y_max = space.height * y_spacing x_padding = size * math.sqrt(3) / 2 y_padding = size - # Prepare data for grid lines - hex_lines = [] - hex_centers = [] - - for idx, hexagon in enumerate(hexagons): - # Calculate center of this hexagon - x_center = sum(p[0] for p in hexagon) / 6 - y_center = sum(p[1] for p in hexagon) / 6 - - # Calculate row and column from index - row = idx // space.width - col = idx % space.width - - # Store center - hex_centers.append((col, row, x_center, y_center)) - - # Create line segments - for i in range(6): - x1, y1 = hexagon[i] - x2, y2 = hexagon[(i + 1) % 6] - hex_lines.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2}) + # Get hex grid lines using our new function + hex_lines = _get_hexmesh_altair(space.width, space.height, size) # Create grid lines layer grid_lines = alt.Chart(alt.Data(values=hex_lines)).mark_rule( @@ -485,22 +499,6 @@ def _draw_hex_grid(space, agent_portrayal): height=280 ) - # Create mapping from coordinate to center position - center_map = {(col, row): (x, y) for col, row, x, y in hex_centers} - - # Create agents layer - all_agent_data = [] - - for cell in space.all_cells: - for agent in cell.agents: - data = agent_portrayal(agent) - # Get hex center for this cell's coordinate - coord = cell.coordinate - if coord in center_map: - x, y = center_map[coord] - data.update({"x": x, "y": y}) - all_agent_data.append(data) - if not all_agent_data: return grid_lines @@ -508,11 +506,10 @@ def _draw_hex_grid(space, agent_portrayal): agent_layer = alt.Chart( alt.Data(values=all_agent_data) ).mark_circle( - filled=True, - size=150 + filled=True ).encode( - x=alt.X('x:Q', scale=alt.Scale(domain=[-2 * x_padding, x_max + x_padding])), - y=alt.Y('y:Q', scale=alt.Scale(domain=[-2 * y_padding, y_max + y_padding])), + x=alt.X('x:Q', scale=alt.Scale(domain=[-x_padding, x_max + x_padding]),axis=alt.Axis(grid=False)), + y=alt.Y('y:Q', scale=alt.Scale(domain=[-y_padding, y_max + y_padding]),axis=alt.Axis(grid=False)), ).properties( width=280, height=280 @@ -522,9 +519,13 @@ def _draw_hex_grid(space, agent_portrayal): if all_agent_data and "color" in all_agent_data[0]: agent_layer = agent_layer.encode(color=alt.Color("color:N")) + # Add size encoding if present if all_agent_data and "size" in all_agent_data[0]: agent_layer = agent_layer.encode(size=alt.Size("size:Q")) + else: + agent_layer = agent_layer.mark_circle(filled=True, size=s_default) + # Layer grid and agents together chart = (grid_lines + agent_layer).resolve_scale( x='shared', y='shared' @@ -533,82 +534,157 @@ def _draw_hex_grid(space, agent_portrayal): return chart -def _draw_network_grid(space, agent_portrayal): - """Create Altair visualization for Network Grid.""" - all_agent_data = [] - for node in space.G.nodes(): - agents = space.G.nodes[node].get("agent", []) - if not isinstance(agents, list): - agents = [agents] if agents else [] - - for agent in agents: - if agent: - data = agent_portrayal(agent) - pos = space.G.nodes[node].get("pos", (0, 0)) - data.update({"x": pos[0], "y": pos[1]}) - all_agent_data.append(data) - - if not all_agent_data: +def _draw_network_grid( + space: NetworkGrid | mesa.discrete_space.network.Network, + agent_portrayal: Callable, + draw_grid: bool = True, + layout_alg=nx.spring_layout, + layout_kwargs=None, + **kwargs, +): + """Create Altair visualization for Network Grid. + + Args: + space: The network space to visualize + agent_portrayal: A callable that defines how agents are portrayed + draw_grid: Whether to draw the network edges + layout_alg: A NetworkX layout algorithm to position nodes + layout_kwargs: Arguments to pass to the layout algorithm + """ + if layout_kwargs is None: + layout_kwargs = {"seed": 0} + + # Get the graph and calculate positions using layout algorithm + graph = space.G + pos = layout_alg(graph, **layout_kwargs) + + # Calculate bounds with padding + x_values = [p[0] for p in pos.values()] + y_values = [p[1] for p in pos.values()] + xmin, xmax = min(x_values), max(x_values) + ymin, ymax = min(y_values), max(y_values) + + width = xmax - xmin + height = ymax - ymin + x_padding = width / 20 + y_padding = height / 20 + + # Gather agent data using positions from layout algorithm + s_default = (180 / max(width, height)) ** 2 + raw_data = collect_agent_data(space, agent_portrayal) + + # Early exit if no agents + if len(raw_data["loc"]) == 0: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - - invalid_tooltips = ["color", "size", "x", "y", "node"] - x_y_type = "quantitative" - - # Get x, y coordinates and determine bounds - positions = [space.G.nodes[node].get("pos", (0, 0)) for node in space.G.nodes()] - x_values = [p[0] for p in positions] - y_values = [p[1] for p in positions] - - # Add padding to the bounds - padding = 0.1 # 10% padding - x_min, x_max = min(x_values), max(x_values) - y_min, y_max = min(y_values), max(y_values) - x_range = x_max - x_min - y_range = y_max - y_min - - x_scale = alt.Scale(domain=(x_min - padding * x_range, x_max + padding * x_range)) - y_scale = alt.Scale(domain=(y_min - padding * y_range, y_max + padding * y_range)) - - encoding_dict = { - "x": alt.X("x", axis=alt.Axis(grid=True), type=x_y_type, scale=x_scale), - "y": alt.Y("y", axis=alt.Axis(grid=True), type=x_y_type, scale=y_scale), - "tooltip": [ - alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) - for key, value in all_agent_data[0].items() - if key not in invalid_tooltips - ], - } - - has_color = "color" in all_agent_data[0] - if has_color: - encoding_dict["color"] = alt.Color("color", type="nominal") - has_size = "size" in all_agent_data[0] - if has_size: - encoding_dict["size"] = alt.Size("size", type="quantitative") - - chart = ( - alt.Chart( - alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) - ) - .mark_point(filled=True) - .properties(width=280, height=280) + + # Map agent positions to layout positions + loc = raw_data["loc"] + positions = np.array([pos[node_id] for node_id in loc]) + + # Create agent data for Altair + all_agent_data = [] + for i in range(len(loc)): + agent_dict = { + "x": float(positions[i][0]), + "y": float(positions[i][1]), + "color": raw_data["c"][i], + "size": raw_data["s"][i], + "node_id": int(loc[i]) # Keep node ID for reference + } + # Add other properties if they exist + if len(raw_data["alpha"]) > i: + agent_dict["alpha"] = raw_data["alpha"][i] + if len(raw_data["edgecolors"]) > i: + agent_dict["edgecolor"] = raw_data["edgecolors"][i] + if len(raw_data["linewidths"]) > i: + agent_dict["linewidth"] = raw_data["linewidths"][i] + + all_agent_data.append(agent_dict) + + # Create edge data for drawing network connections + edge_data = [] + if draw_grid: + for u, v in graph.edges(): + edge_data.append({ + "x1": pos[u][0], + "y1": pos[u][1], + "x2": pos[v][0], + "y2": pos[v][1] + }) + + # Create base chart for agents + agent_chart = alt.Chart( + alt.Data(values=all_agent_data) + ).mark_circle( + filled=True + ).encode( + x=alt.X('x:Q', scale=alt.Scale(domain=[xmin - x_padding, xmax + x_padding]), axis=alt.Axis(grid=False)), + y=alt.Y('y:Q', scale=alt.Scale(domain=[ymin - y_padding, ymax + y_padding]),axis=alt.Axis(grid=False)), + ).properties( + width=280, + height=280 ) - - return chart + + # Add color and size encodings if present + if all_agent_data: + if "color" in all_agent_data[0]: + agent_chart = agent_chart.encode(color=alt.Color("color:N")) + + if "size" in all_agent_data[0]: + agent_chart = agent_chart.encode(size=alt.Size("size:Q")) + else: + agent_chart = agent_chart.mark_circle(filled=True, size=s_default) + + # Create edge chart + if draw_grid and edge_data: + edge_chart = alt.Chart( + alt.Data(values=edge_data) + ).mark_rule( + color='gray', + strokeDash=[5, 5], # Equivalent to "--" style in matplotlib + opacity=0.5, + strokeWidth=1 + ).encode( + x="x1:Q", + y="y1:Q", + x2="x2:Q", + y2="y2:Q" + ) + + # Combine edge and agent charts + return alt.layer(edge_chart, agent_chart) + + return agent_chart def _draw_continuous_space(space, agent_portrayal): """Create Altair visualization for Continuous Space.""" all_agent_data = [] + # Get agent data using the collect_agent_data helper function + raw_data = collect_agent_data(space, agent_portrayal) - for agent in space.agents: - data = agent_portrayal(agent) - data.update({ - "x": float(agent.pos[0]), - "y": float(agent.pos[1]) - }) - all_agent_data.append(data) - + # Early exit if no agents + if len(raw_data["loc"]) == 0: + return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) + + # Convert raw_data (dict of arrays) to Altair format (list of dicts) + all_agent_data = [] + for i in range(len(raw_data["loc"])): + agent_dict = { + "x": float(raw_data["loc"][i][0]), + "y": float(raw_data["loc"][i][1]), + "color": raw_data["c"][i], + "size": raw_data["s"][i] + } + # Add other properties if they exist + if len(raw_data["alpha"]) > i: + agent_dict["alpha"] = raw_data["alpha"][i] + if len(raw_data["edgecolors"]) > i: + agent_dict["edgecolor"] = raw_data["edgecolors"][i] + if len(raw_data["linewidths"]) > i: + agent_dict["linewidth"] = raw_data["linewidths"][i] + + all_agent_data.append(agent_dict) if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) @@ -640,3 +716,44 @@ def _draw_continuous_space(space, agent_portrayal): chart = base.encode(**encodings) return chart + +@lru_cache(maxsize=1024, typed=True) +def _get_hexmesh_altair(width: int, height: int, size: float = 1.0) -> list[dict]: + """Generate hexagon vertices for the mesh in altair format.""" + + # Parameters for hexagon grid + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size + + hex_lines = [] + + # For flat-topped hexagons (note the orientation) + vertices_offsets = [ + (0, -size), # top + (0.5 * np.sqrt(3) * size, -0.5 * size), # top right + (0.5 * np.sqrt(3) * size, 0.5 * size), # bottom right + (0, size), # bottom + (-0.5 * np.sqrt(3) * size, 0.5 * size), # bottom left + (-0.5 * np.sqrt(3) * size, -0.5 * size) # top left + ] + + for row in range(height): + for col in range(width): + # Calculate center position with offset for odd rows + x_center = col * x_spacing + if row % 2 == 1: # Odd rows are offset + x_center += x_spacing / 2 + y_center = row * y_spacing + + # Calculate vertices for this hexagon + vertices = [] + for dx, dy in vertices_offsets: + vertices.append((x_center + dx, y_center + dy)) + + # Create line segments for the hexagon + for i in range(6): + x1, y1 = vertices[i] + x2, y2 = vertices[(i+1) % 6] + hex_lines.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2}) + + return hex_lines From ec7dc7090fa1b2b7fff1c2259d843006d82918ac Mon Sep 17 00:00:00 2001 From: nissu99 Date: Mon, 10 Mar 2025 02:43:11 +0530 Subject: [PATCH 27/29] small --- mesa/visualization/components/altair_components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 96e02c4ecac..dc044987162 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -24,7 +24,6 @@ from mesa.experimental.cell_space import Grid from mesa.space import ContinuousSpace, NetworkGrid, _Grid from mesa.visualization.utils import update_counter -import numpy as np import networkx as nx def make_space_altair(*args, **kwargs): From 348939a752248ddec244064c669134287e7a3ce0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 9 Mar 2025 21:13:24 +0000 Subject: [PATCH 28/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../components/altair_components.py | 209 +++++++++--------- 1 file changed, 105 insertions(+), 104 deletions(-) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index dc044987162..26f9fe766af 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,30 +1,30 @@ """Altair based solara components for visualization mesa spaces.""" import contextlib -import itertools import math import warnings from collections.abc import Callable from functools import lru_cache -import mesa.discrete_space.network -import mesa.visualization -from mesa.visualization.mpl_space_drawing import collect_agent_data,_get_hexmesh import solara import mesa +import mesa.discrete_space.network import mesa.experimental -from mesa.space import HexSingleGrid,HexMultiGrid +import mesa.visualization +from mesa.space import HexMultiGrid, HexSingleGrid +from mesa.visualization.mpl_space_drawing import collect_agent_data with contextlib.suppress(ImportError): import altair as alt +import networkx as nx import numpy as np from mesa.experimental.cell_space import Grid from mesa.space import ContinuousSpace, NetworkGrid, _Grid from mesa.visualization.utils import update_counter -import networkx as nx + def make_space_altair(*args, **kwargs): """Create an Altair chart component for visualizing model space (deprecated). @@ -238,11 +238,11 @@ def SpaceAltair( update_counter.get() space = getattr(model, "grid", None) if space is None: - # Sometimes the space is defined as model.space instead of model.grid + # Sometimes the space is defined as model.space instead of model.grid space = model.space chart = _draw_grid(space, agent_portrayal) - #Apply post_processing if provided + # Apply post_processing if provided if post_process is not None: chart = post_process(chart) @@ -286,7 +286,7 @@ def _draw_grid(space, agent_portrayal): return _draw_discrete_grid(space, agent_portrayal) case _Grid(): return _draw_legacy_grid(space, agent_portrayal) - case NetworkGrid()| mesa.discrete_space.network.Network(): + case NetworkGrid() | mesa.discrete_space.network.Network(): return _draw_network_grid(space, agent_portrayal) case ContinuousSpace() | mesa.experimental.continuous_space.ContinuousSpace(): return _draw_continuous_space(space, agent_portrayal) @@ -294,17 +294,15 @@ def _draw_grid(space, agent_portrayal): raise NotImplementedError(f"Unsupported space type: {type(space)}") - def _draw_discrete_grid(space, agent_portrayal): """Create Altair visualization for Discrete Grid.""" - # Get agent data using the collect_agent_data helper function raw_data = collect_agent_data(space, agent_portrayal) - + # Early exit if no agents if len(raw_data["loc"]) == 0: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - + # Convert raw_data (dict of arrays) to Altair format (list of dicts) all_agent_data = [] for i in range(len(raw_data["loc"])): @@ -312,7 +310,7 @@ def _draw_discrete_grid(space, agent_portrayal): "x": float(raw_data["loc"][i][0]), "y": float(raw_data["loc"][i][1]), "color": raw_data["c"][i], - "size": raw_data["s"][i] + "size": raw_data["s"][i], } # Add other properties if they exist if len(raw_data["alpha"]) > i: @@ -321,7 +319,7 @@ def _draw_discrete_grid(space, agent_portrayal): agent_dict["edgecolor"] = raw_data["edgecolors"][i] if len(raw_data["linewidths"]) > i: agent_dict["linewidth"] = raw_data["linewidths"][i] - + all_agent_data.append(agent_dict) # Create base chart @@ -345,13 +343,13 @@ def _draw_discrete_grid(space, agent_portrayal): if "color" in all_agent_data[0]: encodings["color"] = alt.Color("color:N") - # Add size encoding if present + # Add size encoding if present if "size" in all_agent_data[0]: encodings["size"] = alt.Size("size:Q") chart = base.mark_point(filled=True).encode(**encodings) else: # Default size based on grid dimensions - point_size = 30000 / min(space.width, space.height)**2 + point_size = 30000 / min(space.width, space.height) ** 2 chart = base.mark_point(size=point_size, filled=True).encode(**encodings) return chart @@ -361,11 +359,11 @@ def _draw_legacy_grid(space, agent_portrayal): """Create Altair visualization for Legacy Grid.""" all_agent_data = [] raw_data = collect_agent_data(space, agent_portrayal) - + # Early exit if no agents if len(raw_data["loc"]) == 0: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - + # Convert raw_data (dict of arrays) to Altair format (list of dicts) all_agent_data = [] for i in range(len(raw_data["loc"])): @@ -373,7 +371,7 @@ def _draw_legacy_grid(space, agent_portrayal): "x": float(raw_data["loc"][i][0]), "y": float(raw_data["loc"][i][1]), "color": raw_data["c"][i], - "size": raw_data["s"][i] + "size": raw_data["s"][i], } # Add other properties if they exist if len(raw_data["alpha"]) > i: @@ -382,7 +380,7 @@ def _draw_legacy_grid(space, agent_portrayal): agent_dict["edgecolor"] = raw_data["edgecolors"][i] if len(raw_data["linewidths"]) > i: agent_dict["linewidth"] = raw_data["linewidths"][i] - + all_agent_data.append(agent_dict) if not all_agent_data: @@ -447,21 +445,21 @@ def _draw_hex_grid(space, agent_portrayal): # Get color and size defaults s_default = (180 / max(space.width, space.height)) ** 2 - + # Get agent data using the collect_agent_data helper function raw_data = collect_agent_data(space, agent_portrayal) - + # Early exit if no agents if len(raw_data["loc"]) == 0: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - + # Transform hex coordinates to pixel coordinates loc = raw_data["loc"].astype(float) if loc.size > 0: # Apply the hex grid transformation for agent positions loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1] % 2) * (x_spacing / 2)) loc[:, 1] = loc[:, 1] * y_spacing - + # Convert raw_data to Altair format all_agent_data = [] for i in range(len(raw_data["loc"])): @@ -469,7 +467,7 @@ def _draw_hex_grid(space, agent_portrayal): "x": float(loc[i][0]), # Use transformed coordinates "y": float(loc[i][1]), # Use transformed coordinates "color": raw_data["c"][i], - "size": raw_data["s"][i] + "size": raw_data["s"][i], } # Add other properties if they exist if len(raw_data["alpha"]) > i: @@ -478,9 +476,9 @@ def _draw_hex_grid(space, agent_portrayal): agent_dict["edgecolor"] = raw_data["edgecolors"][i] if len(raw_data["linewidths"]) > i: agent_dict["linewidth"] = raw_data["linewidths"][i] - + all_agent_data.append(agent_dict) - + # Calculate bounds x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2) y_max = space.height * y_spacing @@ -502,16 +500,22 @@ def _draw_hex_grid(space, agent_portrayal): return grid_lines # Create agent points layer - agent_layer = alt.Chart( - alt.Data(values=all_agent_data) - ).mark_circle( - filled=True - ).encode( - x=alt.X('x:Q', scale=alt.Scale(domain=[-x_padding, x_max + x_padding]),axis=alt.Axis(grid=False)), - y=alt.Y('y:Q', scale=alt.Scale(domain=[-y_padding, y_max + y_padding]),axis=alt.Axis(grid=False)), - ).properties( - width=280, - height=280 + agent_layer = ( + alt.Chart(alt.Data(values=all_agent_data)) + .mark_circle(filled=True) + .encode( + x=alt.X( + "x:Q", + scale=alt.Scale(domain=[-x_padding, x_max + x_padding]), + axis=alt.Axis(grid=False), + ), + y=alt.Y( + "y:Q", + scale=alt.Scale(domain=[-y_padding, y_max + y_padding]), + axis=alt.Axis(grid=False), + ), + ) + .properties(width=280, height=280) ) # Add color encoding if present @@ -525,11 +529,8 @@ def _draw_hex_grid(space, agent_portrayal): agent_layer = agent_layer.mark_circle(filled=True, size=s_default) # Layer grid and agents together - chart = (grid_lines + agent_layer).resolve_scale( - x='shared', - y='shared' - ) - + chart = (grid_lines + agent_layer).resolve_scale(x="shared", y="shared") + return chart @@ -542,7 +543,7 @@ def _draw_network_grid( **kwargs, ): """Create Altair visualization for Network Grid. - + Args: space: The network space to visualize agent_portrayal: A callable that defines how agents are portrayed @@ -552,34 +553,34 @@ def _draw_network_grid( """ if layout_kwargs is None: layout_kwargs = {"seed": 0} - + # Get the graph and calculate positions using layout algorithm graph = space.G pos = layout_alg(graph, **layout_kwargs) - + # Calculate bounds with padding x_values = [p[0] for p in pos.values()] y_values = [p[1] for p in pos.values()] xmin, xmax = min(x_values), max(x_values) ymin, ymax = min(y_values), max(y_values) - + width = xmax - xmin height = ymax - ymin x_padding = width / 20 y_padding = height / 20 - + # Gather agent data using positions from layout algorithm s_default = (180 / max(width, height)) ** 2 raw_data = collect_agent_data(space, agent_portrayal) - + # Early exit if no agents if len(raw_data["loc"]) == 0: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - + # Map agent positions to layout positions loc = raw_data["loc"] positions = np.array([pos[node_id] for node_id in loc]) - + # Create agent data for Altair all_agent_data = [] for i in range(len(loc)): @@ -588,7 +589,7 @@ def _draw_network_grid( "y": float(positions[i][1]), "color": raw_data["c"][i], "size": raw_data["s"][i], - "node_id": int(loc[i]) # Keep node ID for reference + "node_id": int(loc[i]), # Keep node ID for reference } # Add other properties if they exist if len(raw_data["alpha"]) > i: @@ -597,62 +598,62 @@ def _draw_network_grid( agent_dict["edgecolor"] = raw_data["edgecolors"][i] if len(raw_data["linewidths"]) > i: agent_dict["linewidth"] = raw_data["linewidths"][i] - + all_agent_data.append(agent_dict) - + # Create edge data for drawing network connections edge_data = [] if draw_grid: for u, v in graph.edges(): - edge_data.append({ - "x1": pos[u][0], - "y1": pos[u][1], - "x2": pos[v][0], - "y2": pos[v][1] - }) - + edge_data.append( + {"x1": pos[u][0], "y1": pos[u][1], "x2": pos[v][0], "y2": pos[v][1]} + ) + # Create base chart for agents - agent_chart = alt.Chart( - alt.Data(values=all_agent_data) - ).mark_circle( - filled=True - ).encode( - x=alt.X('x:Q', scale=alt.Scale(domain=[xmin - x_padding, xmax + x_padding]), axis=alt.Axis(grid=False)), - y=alt.Y('y:Q', scale=alt.Scale(domain=[ymin - y_padding, ymax + y_padding]),axis=alt.Axis(grid=False)), - ).properties( - width=280, - height=280 + agent_chart = ( + alt.Chart(alt.Data(values=all_agent_data)) + .mark_circle(filled=True) + .encode( + x=alt.X( + "x:Q", + scale=alt.Scale(domain=[xmin - x_padding, xmax + x_padding]), + axis=alt.Axis(grid=False), + ), + y=alt.Y( + "y:Q", + scale=alt.Scale(domain=[ymin - y_padding, ymax + y_padding]), + axis=alt.Axis(grid=False), + ), + ) + .properties(width=280, height=280) ) - + # Add color and size encodings if present if all_agent_data: if "color" in all_agent_data[0]: agent_chart = agent_chart.encode(color=alt.Color("color:N")) - + if "size" in all_agent_data[0]: agent_chart = agent_chart.encode(size=alt.Size("size:Q")) else: agent_chart = agent_chart.mark_circle(filled=True, size=s_default) - + # Create edge chart if draw_grid and edge_data: - edge_chart = alt.Chart( - alt.Data(values=edge_data) - ).mark_rule( - color='gray', - strokeDash=[5, 5], # Equivalent to "--" style in matplotlib - opacity=0.5, - strokeWidth=1 - ).encode( - x="x1:Q", - y="y1:Q", - x2="x2:Q", - y2="y2:Q" + edge_chart = ( + alt.Chart(alt.Data(values=edge_data)) + .mark_rule( + color="gray", + strokeDash=[5, 5], # Equivalent to "--" style in matplotlib + opacity=0.5, + strokeWidth=1, + ) + .encode(x="x1:Q", y="y1:Q", x2="x2:Q", y2="y2:Q") ) - + # Combine edge and agent charts return alt.layer(edge_chart, agent_chart) - + return agent_chart @@ -661,11 +662,11 @@ def _draw_continuous_space(space, agent_portrayal): all_agent_data = [] # Get agent data using the collect_agent_data helper function raw_data = collect_agent_data(space, agent_portrayal) - + # Early exit if no agents if len(raw_data["loc"]) == 0: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) - + # Convert raw_data (dict of arrays) to Altair format (list of dicts) all_agent_data = [] for i in range(len(raw_data["loc"])): @@ -673,7 +674,7 @@ def _draw_continuous_space(space, agent_portrayal): "x": float(raw_data["loc"][i][0]), "y": float(raw_data["loc"][i][1]), "color": raw_data["c"][i], - "size": raw_data["s"][i] + "size": raw_data["s"][i], } # Add other properties if they exist if len(raw_data["alpha"]) > i: @@ -682,7 +683,7 @@ def _draw_continuous_space(space, agent_portrayal): agent_dict["edgecolor"] = raw_data["edgecolors"][i] if len(raw_data["linewidths"]) > i: agent_dict["linewidth"] = raw_data["linewidths"][i] - + all_agent_data.append(agent_dict) if not all_agent_data: return alt.Chart().mark_text(text="No agents").properties(width=280, height=280) @@ -714,26 +715,26 @@ def _draw_continuous_space(space, agent_portrayal): return chart + @lru_cache(maxsize=1024, typed=True) def _get_hexmesh_altair(width: int, height: int, size: float = 1.0) -> list[dict]: """Generate hexagon vertices for the mesh in altair format.""" - # Parameters for hexagon grid x_spacing = np.sqrt(3) * size y_spacing = 1.5 * size - + hex_lines = [] - + # For flat-topped hexagons (note the orientation) vertices_offsets = [ - (0, -size), # top + (0, -size), # top (0.5 * np.sqrt(3) * size, -0.5 * size), # top right - (0.5 * np.sqrt(3) * size, 0.5 * size), # bottom right - (0, size), # bottom + (0.5 * np.sqrt(3) * size, 0.5 * size), # bottom right + (0, size), # bottom (-0.5 * np.sqrt(3) * size, 0.5 * size), # bottom left - (-0.5 * np.sqrt(3) * size, -0.5 * size) # top left + (-0.5 * np.sqrt(3) * size, -0.5 * size), # top left ] - + for row in range(height): for col in range(width): # Calculate center position with offset for odd rows @@ -741,16 +742,16 @@ def _get_hexmesh_altair(width: int, height: int, size: float = 1.0) -> list[dict if row % 2 == 1: # Odd rows are offset x_center += x_spacing / 2 y_center = row * y_spacing - + # Calculate vertices for this hexagon vertices = [] for dx, dy in vertices_offsets: vertices.append((x_center + dx, y_center + dy)) - + # Create line segments for the hexagon for i in range(6): x1, y1 = vertices[i] - x2, y2 = vertices[(i+1) % 6] + x2, y2 = vertices[(i + 1) % 6] hex_lines.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2}) - + return hex_lines From b0192a61cdf24376c827d5609f98fb3406340ce7 Mon Sep 17 00:00:00 2001 From: nissu99 Date: Mon, 10 Mar 2025 02:48:43 +0530 Subject: [PATCH 29/29] docstring --- mesa/visualization/components/altair_components.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index dc044987162..c23c858ab96 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -549,6 +549,7 @@ def _draw_network_grid( draw_grid: Whether to draw the network edges layout_alg: A NetworkX layout algorithm to position nodes layout_kwargs: Arguments to pass to the layout algorithm + **kwargs: Additional keyword arguments passed to the visualization """ if layout_kwargs is None: layout_kwargs = {"seed": 0}