Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions mesa/experimental/jupyter_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def make_model():

with solara.GridFixed(columns=2):
# 4. Space
if space_drawer is None:
make_space(model, agent_portrayal)
else:
space_drawer(model, agent_portrayal)
SpaceView(space_drawer, model, agent_portrayal, current_step)
# 5. Plots
for measure in measures:
if callable(measure):
Expand Down Expand Up @@ -182,15 +179,16 @@ def make_user_input(user_input, name, options):
raise ValueError(f"{input_type} is not a supported input type")


def make_space(model, agent_portrayal):
def portray(g):
@solara.component
def GridView(grid, agent_portrayal, dummy):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dummy as a name is rather non-descriptive and confusing.

def portray(grid):
x = []
y = []
s = [] # size
c = [] # color
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
sizes = []
colors = []
for i in range(grid.width):
for j in range(grid.height):
content = grid._grid[i][j]
if not content:
continue
if not hasattr(content, "__iter__"):
Expand All @@ -201,35 +199,50 @@ def portray(g):
x.append(i)
y.append(j)
if "size" in data:
s.append(data["size"])
sizes.append(data["size"])
if "color" in data:
c.append(data["color"])
colors.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c
if len(sizes) > 0:
out["s"] = sizes
if len(colors) > 0:
out["c"] = colors
return out

space_fig = Figure()
space_ax = space_fig.subplots()
if isinstance(model.grid, mesa.space.NetworkGrid):
_draw_network_grid(model, space_ax, agent_portrayal)
else:
space_ax.scatter(**portray(model.grid))
space_ax.scatter(**portray(grid))
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig)


def _draw_network_grid(model, space_ax, agent_portrayal):
graph = model.grid.G
@solara.component
def SpaceView(space_drawer, model, agent_portrayal, dummy):
if space_drawer is not None:
return space_drawer(model, agent_portrayal)

if isinstance(model.grid, mesa.space.NetworkGrid):
return NetworkView(model.grid.G, agent_portrayal, dummy)

if isinstance(model.grid, mesa.space._Grid):
return GridView(model.grid, agent_portrayal, dummy)

raise ValueError(f"Unsupported space type: {type(model.grid)}")


@solara.component
def NetworkView(graph, agent_portrayal, dummy):
space_fig = Figure()
space_ax = space_fig.subplots()
pos = nx.spring_layout(graph, seed=0)
nx.draw(
graph,
ax=space_ax,
pos=pos,
**agent_portrayal(graph),
)
space_ax.set_axis_off()
solara.FigureMatplotlib(space_fig)


def make_plot(model, measure):
Expand Down