From 4696db6c9e881b1fead541543f1c5d5b44d0d6cd Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Tue, 29 Oct 2024 17:32:37 +0100 Subject: [PATCH] Add runtime tests --- .github/workflows/tests.yml | 2 +- tests/test_runtime.py | 121 ++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 tests/test_runtime.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3eb90b0a..f0562742 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,4 +39,4 @@ jobs: - name: Test with pytest run: | pip install pytest pytest-cov - pytest tests/ --cov=jaxley --cov-report=xml + pytest tests/ -m "not run_locally" --cov=jaxley --cov-report=xml diff --git a/tests/test_runtime.py b/tests/test_runtime.py new file mode 100644 index 00000000..fb37d256 --- /dev/null +++ b/tests/test_runtime.py @@ -0,0 +1,121 @@ +import os +import time + +import numpy as np +import pytest +from jax import jit + +import jaxley as jx +from jaxley.channels import HH +from jaxley.connect import sparse_connect +from jaxley.synapses import IonotropicSynapse + + +def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0): + _ = np.random.seed(1) # For sparse connectivity matrix. + + if artificial: + comp = jx.Compartment() + branch = jx.Branch(comp, 2) + depth = 3 + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + cell = jx.Cell(branch, parents=parents) + else: + dirname = os.path.dirname(__file__) + fname = os.path.join(dirname, "swc_files", "morph.swc") + cell = jx.read_swc(fname, nseg=4) + net = jx.Network([cell for _ in range(num_cells)]) + + # Channels. + net.insert(HH()) + + # Synapses. + if connect: + sparse_connect( + net.cell("all"), net.cell("all"), IonotropicSynapse(), connection_prob + ) + + # Recordings. + net[0, 1, 0].record(verbose=False) + + # Trainables. + net.make_trainable("radius", verbose=False) + params = net.get_parameters() + + net.to_jax() + return net, params + + +@pytest.mark.run_locally +@pytest.mark.parametrize( + "num_cells, artificial, connect, connection_prob, voltage_solver, identifier", + ( + # Test a single SWC cell with both solvers. + pytest.param(1, False, False, 0.0, "jaxley.stone", 0), + pytest.param(1, False, False, 0.0, "jax.sparse", 1), + # Test a network of SWC cells with both solvers. + pytest.param(10, False, True, 0.1, "jaxley.stone", 2), + pytest.param(10, False, True, 0.1, "jax.sparse", 3), + # Test a larger network of smaller neurons with both solvers. + pytest.param(1000, True, True, 0.001, "jaxley.stone", 4), + pytest.param(1000, True, True, 0.001, "jax.sparse", 5), + ), +) +def test_runtime( + num_cells: int, + artificial: bool, + connect: bool, + connection_prob: float, + voltage_solver: str, + identifier: int, +): + delta_t = 0.025 + t_max = 100.0 + + net, params = build_net( + num_cells, + artificial=artificial, + connect=connect, + connection_prob=connection_prob, + ) + + def simulate(params): + return jx.integrate( + net, + params=params, + t_max=t_max, + delta_t=delta_t, + voltage_solver=voltage_solver, + ) + + jitted_simulate = jit(simulate) + + start_time = time.time() + _ = jitted_simulate(params).block_until_ready() + compile_time = time.time() - start_time + + params[0]["radius"] = params[0]["radius"].at[0].set(0.5) + start_time = time.time() + _ = jitted_simulate(params).block_until_ready() + run_time = time.time() - start_time + + compile_times = { + 0: 16.858529806137085, + 1: 0.8063809871673584, + 2: 5.4792890548706055, + 3: 6.175129175186157, + 4: 2.755805015563965, + 5: 13.303060293197632, + } + run_times = { + 0: 0.08291006088256836, + 1: 0.596994161605835, + 2: 0.8518729209899902, + 3: 5.746302127838135, + 4: 1.3585789203643799, + 5: 12.48916506767273, + } + + tolerance = 1.2 + assert compile_time < compile_times[identifier] * tolerance + assert run_time < run_times[identifier] * tolerance