@@ -132,17 +132,18 @@ def test_wrapper(**kwargs):
132
132
133
133
append_to_json (fpath_results , header ["test_name" ], header ["input_kwargs" ], runtimes )
134
134
135
- assert key in baselines , f"No basline found for { header } "
136
- func_baselines = baselines [key ]["runtimes" ]
137
- for key , baseline in func_baselines .items ():
138
- diff = (
139
- float ("nan" )
140
- if np .isclose (baseline , 0 )
141
- else (runtimes [key ] - baseline ) / baseline
142
- )
143
- assert runtimes [key ] < baseline * (
144
- 1 + tolerance
145
- ), f"{ key } is { diff :.2%} slower than the baseline."
135
+ if not NEW_BASELINE :
136
+ assert key in baselines , f"No basline found for { header } "
137
+ func_baselines = baselines [key ]["runtimes" ]
138
+ for key , baseline in func_baselines .items ():
139
+ diff = (
140
+ float ("nan" )
141
+ if np .isclose (baseline , 0 )
142
+ else (runtimes [key ] - baseline ) / baseline
143
+ )
144
+ assert runtimes [key ] < baseline * (
145
+ 1 + tolerance
146
+ ), f"{ key } is { diff :.2%} slower than the baseline."
146
147
147
148
return test_wrapper
148
149
@@ -187,13 +188,13 @@ def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
187
188
(
188
189
# Test a single SWC cell with both solvers.
189
190
pytest .param (1 , False , False , 0.0 , "jaxley.stone" ),
190
- # pytest.param(1, False, False, 0.0, "jax.sparse"),
191
- # # Test a network of SWC cells with both solvers.
192
- # pytest.param(10, False, True, 0.1, "jaxley.stone"),
193
- # pytest.param(10, False, True, 0.1, "jax.sparse"),
194
- # # Test a larger network of smaller neurons with both solvers.
195
- # pytest.param(1000, True, True, 0.001, "jaxley.stone"),
196
- # pytest.param(1000, True, True, 0.001, "jax.sparse"),
191
+ pytest .param (1 , False , False , 0.0 , "jax.sparse" ),
192
+ # Test a network of SWC cells with both solvers.
193
+ pytest .param (10 , False , True , 0.1 , "jaxley.stone" ),
194
+ pytest .param (10 , False , True , 0.1 , "jax.sparse" ),
195
+ # Test a larger network of smaller neurons with both solvers.
196
+ pytest .param (1000 , True , True , 0.001 , "jaxley.stone" ),
197
+ pytest .param (1000 , True , True , 0.001 , "jax.sparse" ),
197
198
),
198
199
)
199
200
@compare_to_baseline (baseline_iters = 3 )
@@ -204,41 +205,37 @@ def test_runtime(
204
205
connection_prob : float ,
205
206
voltage_solver : str ,
206
207
):
207
- import time
208
- # delta_t = 0.025
209
- # t_max = 100.0
210
-
211
- # def simulate(params):
212
- # return jx.integrate(
213
- # net,
214
- # params=params,
215
- # t_max=t_max,
216
- # delta_t=delta_t,
217
- # voltage_solver=voltage_solver,
218
- # )
208
+ delta_t = 0.025
209
+ t_max = 100.0
210
+
211
+ def simulate (params ):
212
+ return jx .integrate (
213
+ net ,
214
+ params = params ,
215
+ t_max = t_max ,
216
+ delta_t = delta_t ,
217
+ voltage_solver = voltage_solver ,
218
+ )
219
219
220
220
runtimes = {}
221
221
222
222
start_time = time .time ()
223
- # net, params = build_net(
224
- # num_cells,
225
- # artificial=artificial,
226
- # connect=connect,
227
- # connection_prob=connection_prob,
228
- # )
229
- time .sleep (0.1 )
223
+ net , params = build_net (
224
+ num_cells ,
225
+ artificial = artificial ,
226
+ connect = connect ,
227
+ connection_prob = connection_prob ,
228
+ )
230
229
runtimes ["build_time" ] = time .time () - start_time
231
230
232
- # jitted_simulate = jit(simulate)
231
+ jitted_simulate = jit (simulate )
233
232
234
233
start_time = time .time ()
235
- time .sleep (0.2 )
236
- # _ = jitted_simulate(params).block_until_ready()
234
+ _ = jitted_simulate (params ).block_until_ready ()
237
235
runtimes ["compile_time" ] = time .time () - start_time
238
- # params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
236
+ params [0 ]["radius" ] = params [0 ]["radius" ].at [0 ].set (0.5 )
239
237
240
238
start_time = time .time ()
241
- # _ = jitted_simulate(params).block_until_ready()
242
- time .sleep (0.31 )
239
+ _ = jitted_simulate (params ).block_until_ready ()
243
240
runtimes ["run_time" ] = time .time () - start_time
244
241
return runtimes # @compare_to_baseline decorator will compare this to the baseline
0 commit comments