diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 61ccf0d6648..2ad3b249fa0 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -3,10 +3,14 @@ import warnings import altair as alt +import numpy as np +import pandas as pd import solara +from matplotlib.colors import to_rgb +import mesa from mesa.discrete_space import DiscreteSpace, Grid -from mesa.space import ContinuousSpace, _Grid +from mesa.space import ContinuousSpace, PropertyLayer, _Grid from mesa.visualization.utils import update_counter @@ -20,13 +24,16 @@ def make_space_altair(*args, **kwargs): # noqa: D103 def make_altair_space( - agent_portrayal, propertylayer_portrayal, post_process, **space_drawing_kwargs + agent_portrayal, + propertylayer_portrayal=None, + post_process=None, + **space_drawing_kwargs, ): """Create an Altair-based space visualization component. Args: agent_portrayal: Function to portray agents. - propertylayer_portrayal: not yet implemented + propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks) space_drawing_kwargs : not yet implemented @@ -43,14 +50,23 @@ def agent_portrayal(a): return {"id": a.unique_id} def MakeSpaceAltair(model): - return SpaceAltair(model, agent_portrayal, post_process=post_process) + return SpaceAltair( + model, + agent_portrayal, + propertylayer_portrayal=propertylayer_portrayal, + post_process=post_process, + ) return MakeSpaceAltair @solara.component def SpaceAltair( - model, agent_portrayal, dependencies: list[any] | None = None, post_process=None + model, + agent_portrayal, + propertylayer_portrayal=None, + dependencies: list[any] | None = None, + post_process=None, ): """Create an Altair-based space visualization component. @@ -63,10 +79,11 @@ def SpaceAltair( # Sometimes the space is defined as model.space instead of model.grid space = model.space - chart = _draw_grid(space, agent_portrayal) + chart = _draw_grid(space, agent_portrayal, propertylayer_portrayal) # Apply post-processing if provided if post_process is not None: chart = post_process(chart) + solara.FigureAltair(chart) @@ -138,7 +155,7 @@ def _get_agent_data_continuous_space(space: ContinuousSpace, agent_portrayal): return all_agent_data -def _draw_grid(space, agent_portrayal): +def _draw_grid(space, agent_portrayal, propertylayer_portrayal): match space: case Grid(): all_agent_data = _get_agent_data_new_discrete_space(space, agent_portrayal) @@ -168,23 +185,266 @@ def _draw_grid(space, agent_portrayal): } has_color = "color" in all_agent_data[0] if has_color: - encoding_dict["color"] = alt.Color("color", type="nominal") + unique_colors = list({agent["color"] for agent in all_agent_data}) + encoding_dict["color"] = alt.Color( + "color:N", + scale=alt.Scale(domain=unique_colors, range=unique_colors), + ) has_size = "size" in all_agent_data[0] if has_size: encoding_dict["size"] = alt.Size("size", type="quantitative") - chart = ( + agent_chart = ( alt.Chart( alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) ) .mark_point(filled=True) - .properties(width=280, height=280) - # .configure_view(strokeOpacity=0) # hide grid/chart lines + .properties(width=300, height=300) ) - # This is the default value for the marker size, which auto-scales - # according to the grid area. + base_chart = None + cbar_chart = None + + # This is the default value for the marker size, which auto-scales according to the grid area. if not has_size: length = min(space.width, space.height) - chart = chart.mark_point(size=30000 / length**2, filled=True) + agent_chart = agent_chart.mark_point(size=30000 / length**2, filled=True) + + if propertylayer_portrayal is not None: + chart_width = agent_chart.properties().width + chart_height = agent_chart.properties().height + base_chart, cbar_chart = chart_property_layers( + space=space, + propertylayer_portrayal=propertylayer_portrayal, + chart_width=chart_width, + chart_height=chart_height, + ) + + base_chart = alt.layer(base_chart, agent_chart) + else: + base_chart = agent_chart + if cbar_chart is not None: + base_chart = alt.vconcat(base_chart, cbar_chart).configure_view(stroke=None) + return base_chart + + +def chart_property_layers(space, propertylayer_portrayal, chart_width, chart_height): + """Creates Property Layers in the Altair Components. + + Args: + space: the ContinuousSpace instance + propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications + chart_width: width of the agent chart to maintain consistency with the property charts + chart_height: height of the agent chart to maintain consistency with the property charts + agent_chart: the agent chart to layer with the property layers on the grid + Returns: + Altair Chart + """ + try: + # old style spaces + property_layers = space.properties + except AttributeError: + # new style spaces + property_layers = space._mesa_property_layers + base = None + bar_chart = None + for layer_name, portrayal in propertylayer_portrayal.items(): + layer = property_layers.get(layer_name, None) + if not isinstance( + layer, + PropertyLayer | mesa.discrete_space.property_layer.PropertyLayer, + ): + continue - return chart + data = layer.data.astype(float) if layer.data.dtype == bool else layer.data + + if (space.width, space.height) != data.shape: + warnings.warn( + f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).", + UserWarning, + stacklevel=2, + ) + alpha = portrayal.get("alpha", 1) + vmin = portrayal.get("vmin", np.min(data)) + vmax = portrayal.get("vmax", np.max(data)) + colorbar = portrayal.get("colorbar", True) + + # Prepare data for Altair (convert 2D array to a long-form DataFrame) + df = pd.DataFrame( + { + "x": np.repeat(np.arange(data.shape[0]), data.shape[1]), + "y": np.tile(np.arange(data.shape[1]), data.shape[0]), + "value": data.flatten(), + } + ) + + if "color" in portrayal: + # Create a function to map values to RGBA colors with proper opacity scaling + def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal): + """Maps data values to RGBA colors with opacity based on value magnitude. + + Args: + val: The data value to convert + vmin: The smallest value for which the color is displayed in the colorbar + vmax: The largest value for which the color is displayed in the colorbar + alpha: The opacity of the color + portrayal: The specifics of the current property layer in the iterative loop + + Returns: + String representation of RGBA color + """ + # Normalize value to range [0,1] and clamp + normalized = max(0, min((val - vmin) / (vmax - vmin), 1)) + + # Scale opacity by alpha parameter + opacity = normalized * alpha + + # Convert color to RGB components + rgb_color = to_rgb(portrayal["color"]) + r = int(rgb_color[0] * 255) + g = int(rgb_color[1] * 255) + b = int(rgb_color[2] * 255) + + return f"rgba({r}, {g}, {b}, {opacity:.2f})" + + # Apply color mapping to each value in the dataset + df["color"] = df["value"].apply(apply_rgba) + + # Create chart for the property layer + chart = ( + alt.Chart(df) + .mark_rect() + .encode( + x=alt.X("x:O", axis=None), + y=alt.Y("y:O", axis=None), + fill=alt.Fill("color:N", scale=None), + ) + .properties(width=chart_width, height=chart_height, title=layer_name) + ) + base = alt.layer(chart, base) if base is not None else chart + + # Add colorbar if specified in portrayal + if colorbar: + # Extract RGB components from base color + rgb_color = to_rgb(portrayal["color"]) + r_int = int(rgb_color[0] * 255) + g_int = int(rgb_color[1] * 255) + b_int = int(rgb_color[2] * 255) + + # Define gradient endpoints + min_color = f"rgba({r_int},{g_int},{b_int},0)" + max_color = f"rgba({r_int},{g_int},{b_int},{alpha:.2f})" + + # Define colorbar dimensions + colorbar_height = 20 + colorbar_width = chart_width + + # Create dataframe for gradient visualization + df_gradient = pd.DataFrame({"x": [0, 1], "y": [0, 1]}) + + # Create evenly distributed tick values + axis_values = np.linspace(vmin, vmax, 11) + tick_positions = np.linspace(0, colorbar_width, 11) + + # Prepare data for axis and labels + axis_data = pd.DataFrame({"value": axis_values, "x": tick_positions}) + + # Create colorbar with linear gradient + colorbar_chart = ( + alt.Chart(df_gradient) + .mark_rect( + x=0, + y=0, + width=colorbar_width, + height=colorbar_height, + color=alt.Gradient( + gradient="linear", + stops=[ + alt.GradientStop(color=min_color, offset=0), + alt.GradientStop(color=max_color, offset=1), + ], + x1=0, + x2=1, # Horizontal gradient + y1=0, + y2=0, # Keep y constant + ), + ) + .encode( + x=alt.value(chart_width / 2), # Center colorbar + y=alt.value(0), + ) + .properties(width=colorbar_width, height=colorbar_height) + ) + + # Add tick marks to colorbar + axis_chart = ( + alt.Chart(axis_data) + .mark_tick(thickness=2, size=8) + .encode(x=alt.X("x:Q", axis=None), y=alt.value(colorbar_height - 2)) + ) + + # Add value labels below tick marks + text_labels = ( + alt.Chart(axis_data) + .mark_text(baseline="top", fontSize=10, dy=0) + .encode( + x=alt.X("x:Q"), + text=alt.Text("value:Q", format=".1f"), + y=alt.value(colorbar_height + 10), + ) + ) + + # Add title to colorbar + title = ( + alt.Chart(pd.DataFrame([{"text": layer_name}])) + .mark_text( + fontSize=12, + fontWeight="bold", + baseline="bottom", + align="center", + ) + .encode( + text="text:N", + x=alt.value(colorbar_width / 2), + y=alt.value(colorbar_height + 40), + ) + ) + + # Combine all colorbar components + combined_colorbar = alt.layer( + colorbar_chart, axis_chart, text_labels, title + ).properties(width=colorbar_width, height=colorbar_height + 50) + + bar_chart = ( + alt.vconcat(bar_chart, combined_colorbar) + .resolve_scale(color="independent") + .configure_view(stroke=None) + if bar_chart is not None + else combined_colorbar + ) + + elif "colormap" in portrayal: + cmap = portrayal.get("colormap", "viridis") + cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax]) + + chart = ( + alt.Chart(df) + .mark_rect(opacity=alpha) + .encode( + x=alt.X("x:O", axis=None), + y=alt.Y("y:O", axis=None), + color=alt.Color( + "value:Q", + scale=cmap_scale, + title=layer_name, + legend=alt.Legend(title=layer_name) if colorbar else None, + ), + ) + .properties(width=chart_width, height=chart_height) + ) + base = alt.layer(chart, base) if base is not None else chart + + else: + raise ValueError( + f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." + ) + return base, bar_chart diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index 6e25502e0b7..4621e6a671f 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -9,7 +9,7 @@ import mesa import mesa.visualization.components.altair_components import mesa.visualization.components.matplotlib_components -from mesa.space import MultiGrid +from mesa.space import MultiGrid, PropertyLayer from mesa.visualization.components.altair_components import make_altair_space from mesa.visualization.components.matplotlib_components import make_mpl_space_component from mesa.visualization.solara_viz import ( @@ -102,6 +102,9 @@ def test_call_space_drawer(mocker): # noqa: D103 mock_space_altair = mocker.spy( mesa.visualization.components.altair_components, "SpaceAltair" ) + mock_chart_property_layer = mocker.spy( + mesa.visualization.components.altair_components, "chart_property_layers" + ) class MockAgent(mesa.Agent): def __init__(self, model): @@ -110,7 +113,12 @@ def __init__(self, model): class MockModel(mesa.Model): def __init__(self, seed=None): super().__init__(seed=seed) - self.grid = MultiGrid(width=10, height=10, torus=True) + layer1 = PropertyLayer( + name="sugar", width=10, height=10, default_value=10.0 + ) + self.grid = MultiGrid( + width=10, height=10, torus=True, property_layers=layer1 + ) a = MockAgent(self) self.grid.place_agent(a, (5, 5)) @@ -141,7 +149,15 @@ def agent_portrayal(agent): assert mock_space_altair.call_count == 1 # altair is the default method # checking if SpaceAltair is working as intended with post_process - + propertylayer_portrayal = { + "sugar": { + "colormap": "pastel1", + "alpha": 0.75, + "colorbar": True, + "vmin": 0, + "vmax": 10, + } + } mock_post_process = mocker.MagicMock() solara.render( SolaraViz( @@ -149,8 +165,8 @@ def agent_portrayal(agent): components=[ make_altair_space( agent_portrayal, - propertylayer_portrayal, - mock_post_process, + post_process=mock_post_process, + propertylayer_portrayal=propertylayer_portrayal, ) ], ) @@ -158,13 +174,18 @@ def agent_portrayal(agent): args, kwargs = mock_space_altair.call_args assert args == (model, agent_portrayal) - assert kwargs == {"post_process": mock_post_process} + assert kwargs == { + "post_process": mock_post_process, + "propertylayer_portrayal": propertylayer_portrayal, + } mock_post_process.assert_called_once() + assert mock_chart_property_layer.call_count == 1 assert mock_space_matplotlib.call_count == 0 mock_space_altair.reset_mock() mock_space_matplotlib.reset_mock() mock_post_process.reset_mock() + mock_chart_property_layer.reset_mock() # specify a custom space method class AltSpace: @@ -178,7 +199,7 @@ def drawer(model): # check voronoi space drawer voronoi_model = mesa.Model() - voronoi_model.grid = mesa.experimental.cell_space.VoronoiGrid( + voronoi_model.grid = mesa.discrete_space.VoronoiGrid( centroids_coordinates=[(0, 1), (0, 0), (1, 0)], ) solara.render(