Skip to content

Commit

Permalink
Add runtime tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Oct 29, 2024
1 parent de1e9f2 commit 4696db6
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ jobs:
- name: Test with pytest
run: |
pip install pytest pytest-cov
pytest tests/ --cov=jaxley --cov-report=xml
pytest tests/ -m "not run_locally" --cov=jaxley --cov-report=xml
121 changes: 121 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
import time

import numpy as np
import pytest
from jax import jit

import jaxley as jx
from jaxley.channels import HH
from jaxley.connect import sparse_connect
from jaxley.synapses import IonotropicSynapse


def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
_ = np.random.seed(1) # For sparse connectivity matrix.

if artificial:
comp = jx.Compartment()
branch = jx.Branch(comp, 2)
depth = 3
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
cell = jx.Cell(branch, parents=parents)
else:
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", "morph.swc")
cell = jx.read_swc(fname, nseg=4)
net = jx.Network([cell for _ in range(num_cells)])

# Channels.
net.insert(HH())

# Synapses.
if connect:
sparse_connect(
net.cell("all"), net.cell("all"), IonotropicSynapse(), connection_prob
)

# Recordings.
net[0, 1, 0].record(verbose=False)

# Trainables.
net.make_trainable("radius", verbose=False)
params = net.get_parameters()

net.to_jax()
return net, params


@pytest.mark.run_locally
@pytest.mark.parametrize(
"num_cells, artificial, connect, connection_prob, voltage_solver, identifier",
(
# Test a single SWC cell with both solvers.
pytest.param(1, False, False, 0.0, "jaxley.stone", 0),
pytest.param(1, False, False, 0.0, "jax.sparse", 1),
# Test a network of SWC cells with both solvers.
pytest.param(10, False, True, 0.1, "jaxley.stone", 2),
pytest.param(10, False, True, 0.1, "jax.sparse", 3),
# Test a larger network of smaller neurons with both solvers.
pytest.param(1000, True, True, 0.001, "jaxley.stone", 4),
pytest.param(1000, True, True, 0.001, "jax.sparse", 5),
),
)
def test_runtime(
num_cells: int,
artificial: bool,
connect: bool,
connection_prob: float,
voltage_solver: str,
identifier: int,
):
delta_t = 0.025
t_max = 100.0

net, params = build_net(
num_cells,
artificial=artificial,
connect=connect,
connection_prob=connection_prob,
)

def simulate(params):
return jx.integrate(
net,
params=params,
t_max=t_max,
delta_t=delta_t,
voltage_solver=voltage_solver,
)

jitted_simulate = jit(simulate)

start_time = time.time()
_ = jitted_simulate(params).block_until_ready()
compile_time = time.time() - start_time

params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
start_time = time.time()
_ = jitted_simulate(params).block_until_ready()
run_time = time.time() - start_time

compile_times = {
0: 16.858529806137085,
1: 0.8063809871673584,
2: 5.4792890548706055,
3: 6.175129175186157,
4: 2.755805015563965,
5: 13.303060293197632,
}
run_times = {
0: 0.08291006088256836,
1: 0.596994161605835,
2: 0.8518729209899902,
3: 5.746302127838135,
4: 1.3585789203643799,
5: 12.48916506767273,
}

tolerance = 1.2
assert compile_time < compile_times[identifier] * tolerance
assert run_time < run_times[identifier] * tolerance

0 comments on commit 4696db6

Please sign in to comment.