Commit c1b3ccd

wip: get first baselines again
1 parent 9517fb8 commit c1b3ccd

2 files changed (+42 -45 lines)

tests/conftest.py (+2 -2)

@@ -220,7 +220,7 @@ def update_baseline():
         results = load_json(results_fname)
         with open(baseline_fname, "w") as f:
             json.dump(results, f, indent=2)
-            os.remove(results_fname)
+        os.remove(results_fname)
 
     def print_regression_report():
         baselines = load_json(baseline_fname)

@@ -239,5 +239,5 @@ def print_regression_report():
         print("\n\n\nRegression Test Report\n----------------------\n")
         print(report)
 
-    request.addfinalizer(print_regression_report)
     request.addfinalizer(update_baseline)
+    request.addfinalizer(print_regression_report)
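
Why the finalizer order was swapped: pytest executes finalizers registered with request.addfinalizer in reverse registration order (LIFO), so registering update_baseline first and print_regression_report last means the report runs before update_baseline overwrites the baseline file with the new results. A minimal sketch of that ordering, using a hypothetical fixture that is not part of this test suite:

import pytest

calls = []


@pytest.fixture
def ordered_teardown(request):
    # Finalizers run LIFO: the one registered last executes first.
    request.addfinalizer(lambda: calls.append("registered first, runs last"))
    request.addfinalizer(lambda: calls.append("registered last, runs first"))


def test_order(ordered_teardown):
    pass
    # After teardown: calls == ["registered last, runs first",
    #                           "registered first, runs last"]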

tests/test_regression.py (+40 -43)
@@ -132,17 +132,18 @@ def test_wrapper(**kwargs):
 
         append_to_json(fpath_results, header["test_name"], header["input_kwargs"], runtimes)
 
-        assert key in baselines, f"No basline found for {header}"
-        func_baselines = baselines[key]["runtimes"]
-        for key, baseline in func_baselines.items():
-            diff = (
-                float("nan")
-                if np.isclose(baseline, 0)
-                else (runtimes[key] - baseline) / baseline
-            )
-            assert runtimes[key] < baseline * (
-                1 + tolerance
-            ), f"{key} is {diff:.2%} slower than the baseline."
+        if not NEW_BASELINE:
+            assert key in baselines, f"No basline found for {header}"
+            func_baselines = baselines[key]["runtimes"]
+            for key, baseline in func_baselines.items():
+                diff = (
+                    float("nan")
+                    if np.isclose(baseline, 0)
+                    else (runtimes[key] - baseline) / baseline
+                )
+                assert runtimes[key] < baseline * (
+                    1 + tolerance
+                ), f"{key} is {diff:.2%} slower than the baseline."
 
     return test_wrapper
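
For reference, a hedged sketch of the check this guarded block performs, with illustrative numbers only (the real baseline and tolerance come from the test harness and baseline file, not shown here):

baseline = 2.0                                # seconds stored in the baseline
tolerance = 0.2                               # allow up to a 20% slowdown
runtime = 2.3                                 # newly measured runtime
diff = (runtime - baseline) / baseline        # 0.15, i.e. 15% slower
assert runtime < baseline * (1 + tolerance)   # 2.3 < 2.4, so this passes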

@@ -187,13 +188,13 @@ def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
     (
         # Test a single SWC cell with both solvers.
         pytest.param(1, False, False, 0.0, "jaxley.stone"),
-        # pytest.param(1, False, False, 0.0, "jax.sparse"),
-        # # Test a network of SWC cells with both solvers.
-        # pytest.param(10, False, True, 0.1, "jaxley.stone"),
-        # pytest.param(10, False, True, 0.1, "jax.sparse"),
-        # # Test a larger network of smaller neurons with both solvers.
-        # pytest.param(1000, True, True, 0.001, "jaxley.stone"),
-        # pytest.param(1000, True, True, 0.001, "jax.sparse"),
+        pytest.param(1, False, False, 0.0, "jax.sparse"),
+        # Test a network of SWC cells with both solvers.
+        pytest.param(10, False, True, 0.1, "jaxley.stone"),
+        pytest.param(10, False, True, 0.1, "jax.sparse"),
+        # Test a larger network of smaller neurons with both solvers.
+        pytest.param(1000, True, True, 0.001, "jaxley.stone"),
+        pytest.param(1000, True, True, 0.001, "jax.sparse"),
     ),
 )
 @compare_to_baseline(baseline_iters=3)
@@ -204,41 +205,37 @@ def test_runtime(
     connection_prob: float,
     voltage_solver: str,
 ):
-    import time
-    # delta_t = 0.025
-    # t_max = 100.0
-
-    # def simulate(params):
-    #     return jx.integrate(
-    #         net,
-    #         params=params,
-    #         t_max=t_max,
-    #         delta_t=delta_t,
-    #         voltage_solver=voltage_solver,
-    #     )
+    delta_t = 0.025
+    t_max = 100.0
+
+    def simulate(params):
+        return jx.integrate(
+            net,
+            params=params,
+            t_max=t_max,
+            delta_t=delta_t,
+            voltage_solver=voltage_solver,
+        )
 
     runtimes = {}
 
     start_time = time.time()
-    # net, params = build_net(
-    #     num_cells,
-    #     artificial=artificial,
-    #     connect=connect,
-    #     connection_prob=connection_prob,
-    # )
-    time.sleep(0.1)
+    net, params = build_net(
+        num_cells,
+        artificial=artificial,
+        connect=connect,
+        connection_prob=connection_prob,
+    )
     runtimes["build_time"] = time.time() - start_time
 
-    # jitted_simulate = jit(simulate)
+    jitted_simulate = jit(simulate)
 
     start_time = time.time()
-    time.sleep(0.2)
-    # _ = jitted_simulate(params).block_until_ready()
+    _ = jitted_simulate(params).block_until_ready()
     runtimes["compile_time"] = time.time() - start_time
-    # params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
+    params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
 
     start_time = time.time()
-    # _ = jitted_simulate(params).block_until_ready()
-    time.sleep(0.31)
+    _ = jitted_simulate(params).block_until_ready()
     runtimes["run_time"] = time.time() - start_time
     return runtimes  # @compare_to_baseline decorator will compare this to the baseline
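
The uncommented timing pattern is a common way to separate JAX compile time from run time: the first call to a jitted function includes tracing and compilation, and block_until_ready() forces JAX's asynchronous dispatch to finish before the timer stops, while a second call with inputs of the same shape reuses the compiled program. A small self-contained sketch of the same pattern (the function f and its input are illustrative, not from this repo):

import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.sin(x).sum()


x = jnp.ones(1_000_000)

t0 = time.time()
_ = f(x).block_until_ready()        # first call: includes compile time
compile_time = time.time() - t0

t0 = time.time()
_ = f(x + 1.0).block_until_ready()  # same shape, already compiled: pure run time
run_time = time.time() - t0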
