Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mesa/examples/basic/boltzmann_wealth_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

def agent_portrayal(agent):
return AgentPortrayalStyle(
color=agent.wealth
) # we are using a colormap to translate wealth to color
color=agent.wealth,
tooltip={"Agent ID": agent.unique_id, "Wealth": agent.wealth},
)


model_params = {
Expand All @@ -41,7 +42,7 @@ def post_process(chart):
"""Post-process the Altair chart to add a colorbar legend."""
chart = chart.encode(
color=alt.Color(
"color:N",
"original_color:Q",
scale=alt.Scale(scheme="viridis", domain=[0, 10]),
legend=alt.Legend(
title="Wealth",
Expand All @@ -63,12 +64,12 @@ def post_process(chart):
renderer = SpaceRenderer(model, backend="altair")
# Can customize the grid appearance.
renderer.draw_structure(grid_color="black", grid_dash=[6, 2], grid_opacity=0.3)
renderer.draw_agents(agent_portrayal=agent_portrayal, cmap="viridis", vmin=0, vmax=10)

renderer.draw_agents(agent_portrayal=agent_portrayal)
# The post_process function is used to modify the Altair chart after it has been created.
# It can be used to add legends, colorbars, or other visual elements.
renderer.post_process = post_process


# Creates a line plot component from the model's "Gini" datacollector.
GiniPlot = make_plot_component("Gini")

Expand Down
128 changes: 64 additions & 64 deletions mesa/visualization/backends/altair_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# noqa: D100
"""Altair-based renderer for Mesa spaces.

This module provides an Altair-based renderer for visualizing Mesa model spaces,
agents, and property layers with interactive charting capabilities.
"""

import warnings
from collections.abc import Callable
from dataclasses import fields
Expand Down Expand Up @@ -75,6 +80,7 @@ def collect_agent_data(
"stroke": [], # Stroke color
"strokeWidth": [],
"filled": [],
"tooltip": [],
}

# Import here to avoid circular import issues
Expand Down Expand Up @@ -129,6 +135,7 @@ def collect_agent_data(
linewidths=dict_data.pop(
"linewidths", style_fields.get("linewidths")
),
tooltip=dict_data.pop("tooltip", None),
)
if dict_data:
ignored_keys = list(dict_data.keys())
Expand Down Expand Up @@ -184,6 +191,7 @@ def collect_agent_data(
# FIXME: Make filled user-controllable
filled_value = True
arguments["filled"].append(filled_value)
arguments["tooltip"].append(aps.tooltip)

final_data = {}
for k, v in arguments.items():
Expand Down Expand Up @@ -217,79 +225,71 @@ def draw_agents(
if arguments["loc"].size == 0:
return None

# To get a continuous scale for color the domain should be between [0, 1]
# that's why changing the the domain of strokeWidth beforehand.
stroke_width = [data / 10 for data in arguments["strokeWidth"]]

# Agent data preparation
df_data = {
"x": arguments["loc"][:, 0],
"y": arguments["loc"][:, 1],
"size": arguments["size"],
"shape": arguments["shape"],
"opacity": arguments["opacity"],
"strokeWidth": stroke_width,
"original_color": arguments["color"],
"is_filled": arguments["filled"],
"original_stroke": arguments["stroke"],
}
df = pd.DataFrame(df_data)

# To ensure distinct shapes according to agent portrayal
unique_shape_names_in_data = df["shape"].unique().tolist()

fill_colors = []
stroke_colors = []
for i in range(len(df)):
filled = df["is_filled"][i]
main_color = df["original_color"][i]
stroke_spec = (
df["original_stroke"][i]
if isinstance(df["original_stroke"][i], str)
else None
)
if filled:
fill_colors.append(main_color)
stroke_colors.append(stroke_spec)
Comment on lines -221 to -253
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, it seems that more is changed then just adding a tooltip. Can you explain what changed here and why?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous approach of creating separate DataFrames and joining them (df.join(tooltip_df)) was causing a ValueError: Dataframe contains invalid column name: 0. Pandas was creating integer-based column names, which Altair cannot handle.The new code fixes this by building a list of dictionaries (records), where each dictionary represents a single agent's complete data (position, style, and tooltip). Creating the DataFrame from this list (pd.DataFrame(records)) is a more robust method that ensures all column names are correctly handled as strings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Sahil-Chhoker Can you review these changes? You are best positioned to judge whether this is all done correctly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a needed change to make the dataframe flexible but I've to do some testing before I can say anything.
@DipayanDasgupta can you just double check if you are not forgetting anything to include in the dataframe because it can be harder to spot later on.

# Prepare a list of dictionaries, which is a robust way to create a DataFrame
records = []
for i in range(len(arguments["loc"])):
record = {
"x": arguments["loc"][i][0],
"y": arguments["loc"][i][1],
"size": arguments["size"][i],
"shape": arguments["shape"][i],
"opacity": arguments["opacity"][i],
"strokeWidth": arguments["strokeWidth"][i]
/ 10, # Scale for continuous domain
"original_color": arguments["color"][i],
}
# Add tooltip data if available
tooltip = arguments["tooltip"][i]
if tooltip:
record.update(tooltip)

# Determine fill and stroke colors
if arguments["filled"][i]:
record["viz_fill_color"] = arguments["color"][i]
record["viz_stroke_color"] = (
arguments["stroke"][i]
if isinstance(arguments["stroke"][i], str)
else None
)
else:
fill_colors.append(None)
stroke_colors.append(main_color)
df["viz_fill_color"] = fill_colors
df["viz_stroke_color"] = stroke_colors
record["viz_fill_color"] = None
record["viz_stroke_color"] = arguments["color"][i]

records.append(record)

df = pd.DataFrame(records)

# Ensure all columns that should be numeric are, handling potential Nones
numeric_cols = ["x", "y", "size", "opacity", "strokeWidth", "original_color"]
for col in numeric_cols:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce")

# Get tooltip keys from the first valid record
tooltip_list = ["x", "y"]
if any(t is not None for t in arguments["tooltip"]):
first_valid_tooltip = next(
(t for t in arguments["tooltip"] if t is not None), None
)
if first_valid_tooltip is not None:
tooltip_list.extend(first_valid_tooltip.keys())

# Extract additional parameters from kwargs
# FIXME: Add more parameters to kwargs
title = kwargs.pop("title", "")
xlabel = kwargs.pop("xlabel", "")
ylabel = kwargs.pop("ylabel", "")
# FIXME: Add more parameters to kwargs

# Tooltip list for interactivity
# FIXME: Add more fields to tooltip (preferably from agent_portrayal)
tooltip_list = ["x", "y"]

# Handle custom colormapping
cmap = kwargs.pop("cmap", "viridis")
vmin = kwargs.pop("vmin", None)
vmax = kwargs.pop("vmax", None)

color_is_numeric = np.issubdtype(df["original_color"].dtype, np.number)
if color_is_numeric:
color_min = vmin if vmin is not None else df["original_color"].min()
color_max = vmax if vmax is not None else df["original_color"].max()

fill_encoding = alt.Fill(
"original_color:Q",
scale=alt.Scale(scheme=cmap, domain=[color_min, color_max]),
)
else:
fill_encoding = alt.Fill(
"viz_fill_color:N",
scale=None,
title="Color",
)
color_is_numeric = pd.api.types.is_numeric_dtype(df["original_color"])
fill_encoding = (
alt.Fill("original_color:Q")
if color_is_numeric
else alt.Fill("viz_fill_color:N", scale=None, title="Color")
)
Comment on lines +263 to +288
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix color dtype detection for Altair fill.

pd.to_numeric is coercing string colors (e.g., "tab:blue") to NaN, leaving the column as float and making color_is_numeric truthy. Altair then encodes original_color:Q, so categorical colors disappear. Guard the numeric conversion and reuse that flag in the encoding.

-        numeric_cols = ["x", "y", "size", "opacity", "strokeWidth", "original_color"]
+        color_values = arguments["color"]
+        color_is_numeric = all(
+            isinstance(value, (int, float, np.number)) or value is None
+            for value in color_values
+        )
+
+        numeric_cols = ["x", "y", "size", "opacity", "strokeWidth"]
         for col in numeric_cols:
             if col in df.columns:
                 df[col] = pd.to_numeric(df[col], errors="coerce")
+        if color_is_numeric and "original_color" in df.columns:
+            df["original_color"] = pd.to_numeric(
+                df["original_color"], errors="coerce"
+            )
+        else:
+            color_is_numeric = False
@@
-        color_is_numeric = pd.api.types.is_numeric_dtype(df["original_color"])
         fill_encoding = (
             alt.Fill("original_color:Q")
             if color_is_numeric
             else alt.Fill("viz_fill_color:N", scale=None, title="Color")
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
numeric_cols = ["x", "y", "size", "opacity", "strokeWidth", "original_color"]
for col in numeric_cols:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce")
# Get tooltip keys from the first valid record
tooltip_list = ["x", "y"]
if any(t is not None for t in arguments["tooltip"]):
first_valid_tooltip = next(
(t for t in arguments["tooltip"] if t is not None), None
)
if first_valid_tooltip is not None:
tooltip_list.extend(first_valid_tooltip.keys())
# Extract additional parameters from kwargs
# FIXME: Add more parameters to kwargs
title = kwargs.pop("title", "")
xlabel = kwargs.pop("xlabel", "")
ylabel = kwargs.pop("ylabel", "")
# FIXME: Add more parameters to kwargs
# Tooltip list for interactivity
# FIXME: Add more fields to tooltip (preferably from agent_portrayal)
tooltip_list = ["x", "y"]
# Handle custom colormapping
cmap = kwargs.pop("cmap", "viridis")
vmin = kwargs.pop("vmin", None)
vmax = kwargs.pop("vmax", None)
color_is_numeric = np.issubdtype(df["original_color"].dtype, np.number)
if color_is_numeric:
color_min = vmin if vmin is not None else df["original_color"].min()
color_max = vmax if vmax is not None else df["original_color"].max()
fill_encoding = alt.Fill(
"original_color:Q",
scale=alt.Scale(scheme=cmap, domain=[color_min, color_max]),
)
else:
fill_encoding = alt.Fill(
"viz_fill_color:N",
scale=None,
title="Color",
)
color_is_numeric = pd.api.types.is_numeric_dtype(df["original_color"])
fill_encoding = (
alt.Fill("original_color:Q")
if color_is_numeric
else alt.Fill("viz_fill_color:N", scale=None, title="Color")
)
color_values = arguments["color"]
color_is_numeric = all(
isinstance(value, (int, float, np.number)) or value is None
for value in color_values
)
numeric_cols = ["x", "y", "size", "opacity", "strokeWidth"]
for col in numeric_cols:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce")
if color_is_numeric and "original_color" in df.columns:
df["original_color"] = pd.to_numeric(
df["original_color"], errors="coerce"
)
else:
color_is_numeric = False
# Get tooltip keys from the first valid record
tooltip_list = ["x", "y"]
if any(t is not None for t in arguments["tooltip"]):
first_valid_tooltip = next(
(t for t in arguments["tooltip"] if t is not None), None
)
if first_valid_tooltip is not None:
tooltip_list.extend(first_valid_tooltip.keys())
# Extract additional parameters from kwargs
title = kwargs.pop("title", "")
xlabel = kwargs.pop("xlabel", "")
ylabel = kwargs.pop("ylabel", "")
# FIXME: Add more parameters to kwargs
fill_encoding = (
alt.Fill("original_color:Q")
if color_is_numeric
else alt.Fill("viz_fill_color:N", scale=None, title="Color")
)
🤖 Prompt for AI Agents
In mesa/visualization/backends/altair_backend.py around lines 263-288, the code
coerces "original_color" to numeric which turns string color names into NaN and
makes color_is_numeric incorrectly truthy; fix it by NOT coercing original_color
when converting numeric_cols (remove "original_color" from numeric_cols) or by
computing color_is_numeric from the original column before any coercion (e.g.,
check dtype or try pd.to_numeric on a copy and inspect non-null fraction), then
use that computed boolean to choose alt.Fill("original_color:Q") vs
alt.Fill("viz_fill_color:N", ...); ensure the numeric conversion step only
affects true numeric columns and that the color_is_numeric flag is reused for
the encoding decision.


# Determine space dimensions
xmin, xmax, ymin, ymax = self.space_drawer.get_viz_limits()
unique_shape_names_in_data = df["shape"].dropna().unique().tolist()

chart = (
alt.Chart(df)
Expand Down
5 changes: 4 additions & 1 deletion mesa/visualization/backends/matplotlib_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid


CORRECTION_FACTOR_MARKER_ZOOM = 0.01


Expand Down Expand Up @@ -141,6 +140,10 @@ def collect_agent_data(self, space, agent_portrayal, default_size=None):
)
else:
aps = portray_input
if aps.tooltip is not None:
raise ValueError(
"The 'tooltip' attribute in AgentPortrayalStyle is only supported by the Altair backend."
)
# Set defaults if not provided
if aps.x is None and aps.y is None:
aps.x, aps.y = self._get_agent_pos(agent, space)
Expand Down
2 changes: 2 additions & 0 deletions mesa/visualization/components/portrayal_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class AgentPortrayalStyle:
alpha: float | None = 1.0
edgecolors: str | tuple | None = None
linewidths: float | int | None = 1.0
tooltip: dict | None = None
"""A dictionary of data to display on hover. Note: This feature is only available with the Altair backend."""

def update(self, *updates_fields: tuple[str, Any]):
"""Updates attributes from variable (field_name, new_value) tuple arguments.
Expand Down
1 change: 1 addition & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def test_altair_backend_draw_agents():
"color": np.array(["red", "blue"]),
"filled": np.array([True, True]),
"stroke": np.array(["black", "black"]),
"tooltip": np.array([None, None]),
}
ab.space_drawer.get_viz_limits = MagicMock(return_value=(0, 10, 0, 10))
assert ab.draw_agents(arguments) is not None
Expand Down