Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 240 additions & 13 deletions mesa/visualization/components/altair_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import warnings

import altair as alt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import solara
from matplotlib.colors import to_rgb

import mesa
from mesa.discrete_space import DiscreteSpace, Grid
from mesa.space import ContinuousSpace, _Grid
from mesa.space import ContinuousSpace, PropertyLayer, _Grid
from mesa.visualization.utils import update_counter


Expand All @@ -26,7 +31,7 @@ def make_altair_space(

Args:
agent_portrayal: Function to portray agents.
propertylayer_portrayal: not yet implemented
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks)
space_drawing_kwargs : not yet implemented

Expand All @@ -43,14 +48,20 @@ def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceAltair(model):
return SpaceAltair(model, agent_portrayal, post_process=post_process)
return SpaceAltair(
model, agent_portrayal, propertylayer_portrayal, post_process=post_process
)

return MakeSpaceAltair


@solara.component
def SpaceAltair(
model, agent_portrayal, dependencies: list[any] | None = None, post_process=None
model,
agent_portrayal,
propertylayer_portrayal,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is use of property layer mandatory? (i.e. should this be a keyword argument?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that, just fixed it

dependencies: list[any] | None = None,
post_process=None,
):
"""Create an Altair-based space visualization component.

Expand All @@ -63,10 +74,11 @@ def SpaceAltair(
# Sometimes the space is defined as model.space instead of model.grid
space = model.space

chart = _draw_grid(space, agent_portrayal)
chart = _draw_grid(space, agent_portrayal, propertylayer_portrayal)
# Apply post-processing if provided
if post_process is not None:
chart = post_process(chart)

solara.FigureAltair(chart)


Expand Down Expand Up @@ -138,7 +150,7 @@ def _get_agent_data_continuous_space(space: ContinuousSpace, agent_portrayal):
return all_agent_data


def _draw_grid(space, agent_portrayal):
def _draw_grid(space, agent_portrayal, propertylayer_portrayal):
match space:
case Grid():
all_agent_data = _get_agent_data_new_discrete_space(space, agent_portrayal)
Expand Down Expand Up @@ -168,23 +180,238 @@ def _draw_grid(space, agent_portrayal):
}
has_color = "color" in all_agent_data[0]
if has_color:
encoding_dict["color"] = alt.Color("color", type="nominal")
unique_colors = list({agent["color"] for agent in all_agent_data})
encoding_dict["color"] = alt.Color(
"color:N",
scale=alt.Scale(domain=unique_colors, range=unique_colors),
)
has_size = "size" in all_agent_data[0]
if has_size:
encoding_dict["size"] = alt.Size("size", type="quantitative")

chart = (
agent_chart = (
alt.Chart(
alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
)
.mark_point(filled=True)
.properties(width=280, height=280)
# .configure_view(strokeOpacity=0) # hide grid/chart lines
.properties(width=300, height=300)
)
# This is the default value for the marker size, which auto-scales
# according to the grid area.

# This is the default value for the marker size, which auto-scales according to the grid area.
if not has_size:
length = min(space.width, space.height)
chart = chart.mark_point(size=30000 / length**2, filled=True)
agent_chart = agent_chart.mark_point(size=30000 / length**2, filled=True)

if propertylayer_portrayal is not None:
chart_width = agent_chart.properties().width
chart_height = agent_chart.properties().height
chart = chart_property_layers(
space=space,
propertylayer_portrayal=propertylayer_portrayal,
chart_width=chart_width,
chart_height=chart_height,
)
chart = chart + agent_chart
else:
chart = agent_chart

return chart
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, chart does not exist here, you will have to declare a chart variable above to use it here. The chart declared inside the if statement get destroyed with it.

Copy link
Collaborator Author

@sanika-n sanika-n Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be wrong here but from what I know, in languages like C++ a variable defined inside a loop only exists within that loop’s scope but in python I am fairly sure that variables defined in loops remain accessible outside the loop and since I am defining chart both in the if and else part of the loop, it is definitely going to be defined by the time we reach the return line.
image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're right that Python variables from if/else blocks remain accessible afterward, it's safer to initialize chart at the function level first. This ensures it's always defined regardless of execution path. Could you update your code to follow this pattern? It prevents potential undefined variable issues if your conditions change later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okk, that makes sense, will change it 👍



def chart_property_layers(space, propertylayer_portrayal, chart_width, chart_height):
"""Creates Property Layers in the Altair Components.

Args:
space: the ContinuousSpace instance
propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications
chart_width: width of the agent chart to maintain consistency with the property charts
chart_height: height of the agent chart to maintain consistency with the property charts
Returns:
Altair Chart
"""
try:
# old style spaces
property_layers = space.properties
except AttributeError:
# new style spaces
property_layers = space._mesa_property_layers
base = None
for layer_name, portrayal in propertylayer_portrayal.items():
layer = property_layers.get(layer_name, None)
if not isinstance(
layer,
PropertyLayer | mesa.discrete_space.property_layer.PropertyLayer,
):
continue

data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

if (space.width, space.height) != data.shape:
warnings.warn(
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).",
UserWarning,
stacklevel=2,
)
alpha = portrayal.get("alpha", 1)
vmin = portrayal.get("vmin", np.min(data))
vmax = portrayal.get("vmax", np.max(data))
colorbar = portrayal.get("colorbar", True)

# Prepare data for Altair (convert 2D array to a long-form DataFrame)
df = pd.DataFrame(
{
"x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
"y": np.tile(np.arange(data.shape[1]), data.shape[0]),
"value": data.flatten(),
}
)

if "color" in portrayal:
# any value less than vmin will be mapped to the color corresponding to vmin
# any value more than vmax will be mapped to the color corresponding to vmax
def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
a = (val - vmin) / (vmax - vmin)
a = max(0, min(a, 1)) # to ensure that a is between 0 and 1
a *= alpha # vmax will have an opacity corresponding to alpha
rgb_color = to_rgb(portrayal["color"])
r = int(rgb_color[0] * 255)
g = int(rgb_color[1] * 255)
b = int(rgb_color[2] * 255)

return f"rgba({r}, {g}, {b}, {a:.2f})"

df["color"] = df["value"].apply(apply_rgba)

chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
fill=alt.Fill("color:N", scale=None),
)
.properties(width=chart_width, height=chart_height, title=layer_name)
)
base = (base + chart) if base is not None else chart
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are overcomplicating it a bit, this should work:

Suggested change
if "color" in portrayal:
# any value less than vmin will be mapped to the color corresponding to vmin
# any value more than vmax will be mapped to the color corresponding to vmax
def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
a = (val - vmin) / (vmax - vmin)
a = max(0, min(a, 1)) # to ensure that a is between 0 and 1
a *= alpha # vmax will have an opacity corresponding to alpha
rgb_color = to_rgb(portrayal["color"])
r = int(rgb_color[0] * 255)
g = int(rgb_color[1] * 255)
b = int(rgb_color[2] * 255)
return f"rgba({r}, {g}, {b}, {a:.2f})"
df["color"] = df["value"].apply(apply_rgba)
chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
fill=alt.Fill("color:N", scale=None),
)
.properties(width=chart_width, height=chart_height, title=layer_name)
)
base = (base + chart) if base is not None else chart
import matplotlib.colors as mcolors
if "color" in portrayal:
color = portrayal["color"]
# Convert the user color + alpha to RGBA
rgba = mcolors.to_rgba(color, alpha=alpha)
layer_chart = (
alt.Chart(df)
.mark_rect(
color=f"rgba({int(rgba[0] * 255)}, "
f"{int(rgba[1] * 255)}, "
f"{int(rgba[2] * 255)}, "
f"{rgba[3]})"
)
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
opacity=alt.Opacity(
"value:Q",
scale=alt.Scale(domain=[vmin, vmax], range=[0, 1])
)
)
.properties(width=chart_width, height=chart_height, title=layer_name)
)

Colorbar should be implemented in altair as well, though I am also not very sure how will that work, I have been trying for the past hour to get it right, but either its overcomplicating the code or its not working.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried this out, but I am not able to create a color bar when I am using the code you have suggested as I don't know the mapping of value to color when the Scale function is used and the inbuilt altair legend is discrete and not continous

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but the matplotlib implementation of just the colorbar is also not right.


if colorbar:
list_value = []
list_color = []

i = vmin

while i <= vmax:
list_value.append(i)
list_color.append(apply_rgba(i))
i += 1

if vmax not in list_value:
list_value.append(vmax)
list_color.append(apply_rgba(vmax))
df_colorbar = pd.DataFrame(
{
"value": list_value,
"color": list_color,
}
)

x_values = np.array(df_colorbar["value"])
rgba_colors = np.array(df_colorbar["color"])
# Ensure rgba_colors is a 2D array
if rgba_colors.ndim == 1:
rgba_colors = np.array(
[list(color) for color in rgba_colors]
) # Convert tuples to a 2D array

def parse_rgba(color_str):
if isinstance(color_str, str) and color_str.startswith("rgba"):
color_str = (
color_str.replace("rgba(", "").replace(")", "").split(",")
)
return np.array(
[
float(color_str[i]) / 255
if i < 3
else float(color_str[i])
for i in range(4)
],
dtype=float,
)
return np.array(
color_str, dtype=float
) # If already a tuple, convert to float

# Convert color strings to RGBA tuples (ensures correct dtype)
rgba_colors = np.array(
[parse_rgba(c) for c in df_colorbar["color"]], dtype=float
)

# Ensure rgba_colors is a 2D array with shape (n, 4)
rgba_colors = np.array(rgba_colors).reshape(-1, 4)

# Create an RGBA gradient image (256 steps for smooth transition)
gradient = np.zeros((50, 256, 4)) # (Height, Width, RGBA)

# Interpolate each channel (R, G, B, A) separately
interp_r = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 0],
)
interp_g = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 1],
)
interp_b = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 2],
)
interp_a = np.interp(
np.linspace(0, 255, 256),
np.linspace(0, 255, len(rgba_colors)),
rgba_colors[:, 3],
)

interp_colors = np.stack(
[interp_r, interp_g, interp_b, interp_a], axis=-1
)
gradient[:] = interp_colors
fig, ax = plt.subplots(figsize=(6, 0.25), dpi=100)
ax.imshow(
gradient,
aspect="auto",
extent=[x_values.min(), x_values.max(), 0, 1],
)
ax.set_yticks([])
ax.set_xlabel(layer_name)
ax.set_xticks(np.linspace(x_values.min(), x_values.max(), 11))
plt.show()

elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])

chart = (
Comment on lines +425 to +429
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you apply the alpha to colormaps as well. I think .mark_rect(opacity=alpha) should do the job.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing this out, I totally forgot to implement it...

alt.Chart(df)
.mark_rect(opacity=alpha)
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
color=alt.Color(
"value:Q",
scale=cmap_scale,
title=layer_name,
legend=alt.Legend(title=layer_name) if colorbar else None,
),
)
.properties(width=chart_width, height=chart_height)
)
base = (base + chart) if base is not None else chart

else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)
return base
24 changes: 20 additions & 4 deletions tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import mesa
import mesa.visualization.components.altair_components
import mesa.visualization.components.matplotlib_components
from mesa.space import MultiGrid
from mesa.space import MultiGrid, PropertyLayer
from mesa.visualization.components.altair_components import make_altair_space
from mesa.visualization.components.matplotlib_components import make_mpl_space_component
from mesa.visualization.solara_viz import (
Expand Down Expand Up @@ -102,6 +102,9 @@ def test_call_space_drawer(mocker): # noqa: D103
mock_space_altair = mocker.spy(
mesa.visualization.components.altair_components, "SpaceAltair"
)
mock_chart_property_layer = mocker.spy(
mesa.visualization.components.altair_components, "chart_property_layers"
)

class MockAgent(mesa.Agent):
def __init__(self, model):
Expand All @@ -110,7 +113,10 @@ def __init__(self, model):
class MockModel(mesa.Model):
def __init__(self, seed=None):
super().__init__(seed=seed)
self.grid = MultiGrid(width=10, height=10, torus=True)
layer1 = PropertyLayer(name="sugar", width=10, height=10, default_value=10)
self.grid = MultiGrid(
width=10, height=10, torus=True, property_layers=layer1
)
a = MockAgent(self)
self.grid.place_agent(a, (5, 5))

Expand Down Expand Up @@ -141,7 +147,15 @@ def agent_portrayal(agent):
assert mock_space_altair.call_count == 1 # altair is the default method

# checking if SpaceAltair is working as intended with post_process

propertylayer_portrayal = {
"sugar": {
"colormap": "pastel1",
"alpha": 0.75,
"colorbar": True,
"vmin": 0,
"vmax": 10,
}
}
mock_post_process = mocker.MagicMock()
solara.render(
SolaraViz(
Expand All @@ -157,14 +171,16 @@ def agent_portrayal(agent):
)

args, kwargs = mock_space_altair.call_args
assert args == (model, agent_portrayal)
assert args == (model, agent_portrayal, propertylayer_portrayal)
assert kwargs == {"post_process": mock_post_process}
mock_post_process.assert_called_once()
assert mock_chart_property_layer.call_count == 1
assert mock_space_matplotlib.call_count == 0

mock_space_altair.reset_mock()
mock_space_matplotlib.reset_mock()
mock_post_process.reset_mock()
mock_chart_property_layer.reset_mock()

# specify a custom space method
class AltSpace:
Expand Down
Loading