diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 6007b9a8c8f..4aef81d2d3b 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -48,10 +48,7 @@ def make_model(): with solara.GridFixed(columns=2): # 4. Space - if space_drawer is None: - make_space(model, agent_portrayal) - else: - space_drawer(model, agent_portrayal) + SpaceView(space_drawer, model, agent_portrayal, current_step) # 5. Plots for measure in measures: if callable(measure): @@ -182,15 +179,16 @@ def make_user_input(user_input, name, options): raise ValueError(f"{input_type} is not a supported input type") -def make_space(model, agent_portrayal): - def portray(g): +@solara.component +def GridView(grid, agent_portrayal, dummy): + def portray(grid): x = [] y = [] - s = [] # size - c = [] # color - for i in range(g.width): - for j in range(g.height): - content = g._grid[i][j] + sizes = [] + colors = [] + for i in range(grid.width): + for j in range(grid.height): + content = grid._grid[i][j] if not content: continue if not hasattr(content, "__iter__"): @@ -201,28 +199,41 @@ def portray(g): x.append(i) y.append(j) if "size" in data: - s.append(data["size"]) + sizes.append(data["size"]) if "color" in data: - c.append(data["color"]) + colors.append(data["color"]) out = {"x": x, "y": y} - if len(s) > 0: - out["s"] = s - if len(c) > 0: - out["c"] = c + if len(sizes) > 0: + out["s"] = sizes + if len(colors) > 0: + out["c"] = colors return out space_fig = Figure() space_ax = space_fig.subplots() - if isinstance(model.grid, mesa.space.NetworkGrid): - _draw_network_grid(model, space_ax, agent_portrayal) - else: - space_ax.scatter(**portray(model.grid)) + space_ax.scatter(**portray(grid)) space_ax.set_axis_off() solara.FigureMatplotlib(space_fig) -def _draw_network_grid(model, space_ax, agent_portrayal): - graph = model.grid.G +@solara.component +def SpaceView(space_drawer, model, agent_portrayal, dummy): + if space_drawer is not None: + return space_drawer(model, agent_portrayal) + + if isinstance(model.grid, mesa.space.NetworkGrid): + return NetworkView(model.grid.G, agent_portrayal, dummy) + + if isinstance(model.grid, mesa.space._Grid): + return GridView(model.grid, agent_portrayal, dummy) + + raise ValueError(f"Unsupported space type: {type(model.grid)}") + + +@solara.component +def NetworkView(graph, agent_portrayal, dummy): + space_fig = Figure() + space_ax = space_fig.subplots() pos = nx.spring_layout(graph, seed=0) nx.draw( graph, @@ -230,6 +241,8 @@ def _draw_network_grid(model, space_ax, agent_portrayal): pos=pos, **agent_portrayal(graph), ) + space_ax.set_axis_off() + solara.FigureMatplotlib(space_fig) def make_plot(model, measure):