Skip to content

Commit 8089980

Browse files
Add runtime tests
1 parent de1e9f2 commit 8089980

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ jobs:
3939
- name: Test with pytest
4040
run: |
4141
pip install pytest pytest-cov
42-
pytest tests/ --cov=jaxley --cov-report=xml
42+
pytest tests/ -m "not runtime" --cov=jaxley --cov-report=xml

tests/test_runtime.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import os
2+
import time
3+
4+
import numpy as np
5+
import pytest
6+
from jax import jit
7+
8+
import jaxley as jx
9+
from jaxley.channels import HH
10+
from jaxley.connect import sparse_connect
11+
from jaxley.synapses import IonotropicSynapse
12+
13+
14+
def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
15+
_ = np.random.seed(1) # For sparse connectivity matrix.
16+
17+
if artificial:
18+
comp = jx.Compartment()
19+
branch = jx.Branch(comp, 2)
20+
depth = 3
21+
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
22+
cell = jx.Cell(branch, parents=parents)
23+
else:
24+
dirname = os.path.dirname(__file__)
25+
fname = os.path.join(dirname, "swc_files", "morph.swc")
26+
cell = jx.read_swc(fname, nseg=4)
27+
net = jx.Network([cell for _ in range(num_cells)])
28+
29+
# Channels.
30+
net.insert(HH())
31+
32+
# Synapses.
33+
if connect:
34+
sparse_connect(
35+
net.cell("all"), net.cell("all"), IonotropicSynapse(), connection_prob
36+
)
37+
38+
# Recordings.
39+
net[0, 1, 0].record(verbose=False)
40+
41+
# Trainables.
42+
net.make_trainable("radius", verbose=False)
43+
params = net.get_parameters()
44+
45+
net.to_jax()
46+
return net, params
47+
48+
49+
@pytest.mark.runtime
50+
@pytest.mark.parametrize(
51+
"num_cells, artificial, connect, connection_prob, voltage_solver, identifier",
52+
(
53+
# Test a single SWC cell with both solvers.
54+
pytest.param(1, False, False, 0.0, "jaxley.stone", 0),
55+
pytest.param(1, False, False, 0.0, "jax.sparse", 1),
56+
# Test a network of SWC cells with both solvers.
57+
pytest.param(10, False, True, 0.1, "jaxley.stone", 2),
58+
pytest.param(10, False, True, 0.1, "jax.sparse", 3),
59+
# Test a larger network of smaller neurons with both solvers.
60+
pytest.param(1000, True, True, 0.001, "jaxley.stone", 4),
61+
pytest.param(1000, True, True, 0.001, "jax.sparse", 5),
62+
),
63+
)
64+
def test_runtime(
65+
num_cells: int,
66+
artificial: bool,
67+
connect: bool,
68+
connection_prob: float,
69+
voltage_solver: str,
70+
identifier: int,
71+
):
72+
delta_t = 0.025
73+
t_max = 100.0
74+
75+
net, params = build_net(
76+
num_cells,
77+
artificial=artificial,
78+
connect=connect,
79+
connection_prob=connection_prob,
80+
)
81+
82+
def simulate(params):
83+
return jx.integrate(
84+
net,
85+
params=params,
86+
t_max=t_max,
87+
delta_t=delta_t,
88+
voltage_solver=voltage_solver,
89+
)
90+
91+
jitted_simulate = jit(simulate)
92+
93+
start_time = time.time()
94+
_ = jitted_simulate(params).block_until_ready()
95+
compile_time = time.time() - start_time
96+
97+
params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
98+
start_time = time.time()
99+
_ = jitted_simulate(params).block_until_ready()
100+
run_time = time.time() - start_time
101+
102+
compile_times = {
103+
0: 16.858529806137085,
104+
1: 0.8063809871673584,
105+
2: 5.4792890548706055,
106+
3: 6.175129175186157,
107+
4: 2.755805015563965,
108+
5: 13.303060293197632,
109+
}
110+
run_times = {
111+
0: 0.08291006088256836,
112+
1: 0.596994161605835,
113+
2: 0.8518729209899902,
114+
3: 5.746302127838135,
115+
4: 1.3585789203643799,
116+
5: 12.48916506767273,
117+
}
118+
119+
tolerance = 1.2
120+
assert compile_time < compile_times[identifier] * tolerance
121+
assert run_time < run_times[identifier] * tolerance

0 commit comments

Comments
 (0)