Skip to content

Commit 643b684

Browse files
committed
Split ModelControls into BaseControls and TimelineControls
1 parent 16f3324 commit 643b684

File tree

1 file changed

+60
-31
lines changed

1 file changed

+60
-31
lines changed

mesa/experimental/jupyter_viz.py

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
from typing import Optional
23

34
import matplotlib.pyplot as plt
45
import networkx as nx
@@ -22,6 +23,7 @@ def JupyterViz(
2223
agent_portrayal=None,
2324
space_drawer="default",
2425
play_interval=400,
26+
timeline=False,
2527
):
2628
"""Initialize a component to visualize a model.
2729
Args:
@@ -35,6 +37,7 @@ def JupyterViz(
3537
simulations with no space to visualize should
3638
specify `space_drawer=False`
3739
play_interval: play interval (default: 400)
40+
timeline: whether to display a scrubbable timeline (default: False)
3841
"""
3942

4043
# 1. Set up model parameters
@@ -43,10 +46,7 @@ def JupyterViz(
4346
{**fixed_params, **{k: v["value"] for k, v in user_params.items()}}
4447
)
4548

46-
model, set_model = solara.use_state(
47-
model_class(**model_parameters),
48-
eq=model_eq,
49-
)
49+
model, set_model = solara.use_state(model_class(**model_parameters))
5050
model_cache, set_model_cache = solara.use_state({0: model})
5151

5252
# 2. Set up Model
@@ -57,32 +57,45 @@ def make_model():
5757

5858
solara.use_memo(
5959
make_model,
60-
dependencies=[
61-
list(model_parameters.values()),
62-
],
60+
dependencies=list(model_parameters.values()),
6361
)
6462

6563
def handle_change_model_params(name: str, value: any):
6664
set_model_parameters({**model_parameters, name: value})
6765

68-
def handle_step(step: int):
66+
def handle_step(step: Optional[int] = None):
67+
"""Change the model to the next step, or to the specified step.
68+
69+
If step is specified, the model is cached at that step.
70+
71+
Args:
72+
step: step to change the model to
73+
"""
6974
if not model.running:
7075
return
76+
7177
if step in model_cache:
72-
previous_model = model_cache[step]
73-
set_model(previous_model)
78+
updated_model = model_cache[step]
7479
else:
75-
model.step()
76-
set_model_cache({**model_cache, step: copy.deepcopy(model)})
77-
set_model(model)
80+
updated_model = copy.deepcopy(model)
81+
updated_model.step()
82+
83+
if step is not None:
84+
set_model_cache({**model_cache, step: updated_model})
85+
86+
set_model(updated_model)
7887

7988
# 3. Set up UI
8089
solara.Markdown(name)
8190
UserInputs(user_params, on_change=handle_change_model_params)
82-
ModelControls(
91+
TimelineControls(
8392
play_interval=play_interval,
93+
on_step=handle_step,
94+
on_reset=make_model,
8495
current_step=model.schedule.steps,
8596
max_step=max(model_cache.keys()),
97+
) if timeline else BaseControls(
98+
play_interval=play_interval,
8699
on_step=handle_step,
87100
on_reset=make_model,
88101
)
@@ -106,8 +119,37 @@ def handle_step(step: int):
106119
make_plot(model, measure)
107120

108121

122+
def BaseControls(play_interval, on_step, on_reset):
123+
playing = solara.use_reactive(False)
124+
125+
def on_value_play(_):
126+
if playing.value:
127+
on_step()
128+
129+
def reset():
130+
playing.value = False
131+
on_reset()
132+
133+
with solara.Card(), solara.Row(gap="2px", style={"align-items": "center"}):
134+
with solara.Tooltip("Reset the model"):
135+
solara.Button(icon_name="mdi-reload", color="primary", on_click=reset)
136+
137+
with solara.Tooltip("Step forward"):
138+
solara.Button(label="+1", color="primary", on_click=on_step)
139+
140+
with solara.Tooltip("Start/Stop the model"):
141+
widgets.Play(
142+
interval=play_interval,
143+
show_repeat=False,
144+
on_value=on_value_play,
145+
playing=playing.value,
146+
on_playing=playing.set,
147+
layout=widgets.Layout(height="36px"),
148+
)
149+
150+
109151
@solara.component
110-
def ModelControls(play_interval, current_step, max_step, on_step, on_reset):
152+
def TimelineControls(play_interval, on_step, on_reset, current_step, max_step):
111153
playing = solara.use_reactive(False)
112154

113155
def on_value_play(_):
@@ -135,31 +177,22 @@ def reset():
135177
)
136178
with solara.Tooltip("Step backward"):
137179
solara.Button(
138-
label="|◀",
180+
label="-1",
139181
color="primary",
140182
on_click=change_step(current_step - 1),
141183
)
142-
143-
# This style is necessary so that the play widget has the same
144-
# height as typical Solara buttons.
145-
solara.Style(
146-
"""
147-
.widget-play {
148-
height: 36px;
149-
}
150-
"""
151-
)
152184
with solara.Tooltip("Start/Stop the model"):
153185
widgets.Play(
154186
interval=play_interval,
155187
show_repeat=False,
156188
on_value=on_value_play,
157189
playing=playing.value,
158190
on_playing=playing.set,
191+
layout=widgets.Layout(height="36px"),
159192
)
160193
with solara.Tooltip("Step forward"):
161194
solara.Button(
162-
label="▶|",
195+
label="+1",
163196
color="primary",
164197
on_click=change_step(current_step + 1),
165198
)
@@ -180,10 +213,6 @@ def reset():
180213
)
181214

182215

183-
def model_eq(a, b):
184-
return id(a) == id(b) and (a.schedule.steps == b.schedule.steps)
185-
186-
187216
def split_model_params(model_params):
188217
model_params_input = {}
189218
model_params_fixed = {}

0 commit comments

Comments
 (0)