From 4aa02fd112abb446f1dc9623d5a9ca3c96a202bc Mon Sep 17 00:00:00 2001 From: Corvince Date: Sun, 3 Sep 2023 21:08:30 +0200 Subject: [PATCH 1/2] Convert make_space into solara components --- mesa/experimental/jupyter_viz.py | 63 ++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 6007b9a8c8f..8487741f8f4 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -48,10 +48,7 @@ def make_model(): with solara.GridFixed(columns=2): # 4. Space - if space_drawer is None: - make_space(model, agent_portrayal) - else: - space_drawer(model, agent_portrayal) + SpaceView(space_drawer, model, agent_portrayal) # 5. Plots for measure in measures: if callable(measure): @@ -182,15 +179,16 @@ def make_user_input(user_input, name, options): raise ValueError(f"{input_type} is not a supported input type") -def make_space(model, agent_portrayal): - def portray(g): +@solara.component +def GridView(grid, agent_portrayal): + def portray(grid): x = [] y = [] - s = [] # size - c = [] # color - for i in range(g.width): - for j in range(g.height): - content = g._grid[i][j] + sizes = [] + colors = [] + for i in range(grid.width): + for j in range(grid.height): + content = grid._grid[i][j] if not content: continue if not hasattr(content, "__iter__"): @@ -201,28 +199,45 @@ def portray(g): x.append(i) y.append(j) if "size" in data: - s.append(data["size"]) + sizes.append(data["size"]) if "color" in data: - c.append(data["color"]) + colors.append(data["color"]) out = {"x": x, "y": y} - if len(s) > 0: - out["s"] = s - if len(c) > 0: - out["c"] = c + if len(sizes) > 0: + out["s"] = sizes + if len(colors) > 0: + out["c"] = colors return out space_fig = Figure() space_ax = space_fig.subplots() - if isinstance(model.grid, mesa.space.NetworkGrid): - _draw_network_grid(model, space_ax, agent_portrayal) - else: - space_ax.scatter(**portray(model.grid)) + space_ax.scatter(**portray(grid)) space_ax.set_axis_off() solara.FigureMatplotlib(space_fig) -def _draw_network_grid(model, space_ax, agent_portrayal): - graph = model.grid.G +@solara.component +def SpaceView( + space_drawer, + model, + agent_portrayal, +): + if space_drawer is not None: + return space_drawer(model, agent_portrayal) + + if isinstance(model.grid, mesa.space.NetworkGrid): + return NetworkSpace(model.grid.G, agent_portrayal) + + if isinstance(model.grid, mesa.space._Grid): + return GridView(model.grid, agent_portrayal) + + raise ValueError(f"Unsupported space type: {type(model.grid)}") + + +@solara.component +def NetworkSpace(graph, agent_portrayal): + space_fig = Figure() + space_ax = space_fig.subplots() pos = nx.spring_layout(graph, seed=0) nx.draw( graph, @@ -230,6 +245,8 @@ def _draw_network_grid(model, space_ax, agent_portrayal): pos=pos, **agent_portrayal(graph), ) + space_ax.set_axis_off() + solara.FigureMatplotlib(space_fig) def make_plot(model, measure): From 6d73151fcb94e64c7dc8bc4caa48d04eeda4bd5b Mon Sep 17 00:00:00 2001 From: Corvince Date: Mon, 4 Sep 2023 10:04:50 +0200 Subject: [PATCH 2/2] Add dummy parameter to trigger rerenders --- mesa/experimental/jupyter_viz.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 8487741f8f4..4aef81d2d3b 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -48,7 +48,7 @@ def make_model(): with solara.GridFixed(columns=2): # 4. Space - SpaceView(space_drawer, model, agent_portrayal) + SpaceView(space_drawer, model, agent_portrayal, current_step) # 5. Plots for measure in measures: if callable(measure): @@ -180,7 +180,7 @@ def make_user_input(user_input, name, options): @solara.component -def GridView(grid, agent_portrayal): +def GridView(grid, agent_portrayal, dummy): def portray(grid): x = [] y = [] @@ -217,25 +217,21 @@ def portray(grid): @solara.component -def SpaceView( - space_drawer, - model, - agent_portrayal, -): +def SpaceView(space_drawer, model, agent_portrayal, dummy): if space_drawer is not None: return space_drawer(model, agent_portrayal) if isinstance(model.grid, mesa.space.NetworkGrid): - return NetworkSpace(model.grid.G, agent_portrayal) + return NetworkView(model.grid.G, agent_portrayal, dummy) if isinstance(model.grid, mesa.space._Grid): - return GridView(model.grid, agent_portrayal) + return GridView(model.grid, agent_portrayal, dummy) raise ValueError(f"Unsupported space type: {type(model.grid)}") @solara.component -def NetworkSpace(graph, agent_portrayal): +def NetworkView(graph, agent_portrayal, dummy): space_fig = Figure() space_ax = space_fig.subplots() pos = nx.spring_layout(graph, seed=0)