Skip to content

Commit

Permalink
enh: add slow markers and add remaining fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Nov 15, 2024
1 parent bc5a3cc commit a5303ed
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 123 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ dev = [
"optax",
]

[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (T > 10s)",
]

[tool.isort]
profile = "black"

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def SimpleMorphCell():

cells = {}

def cell_w_params(fname=None, nseg=2, max_branch_len=2_000.0, copy=True):
def cell_w_params(fname=None, nseg=1, max_branch_len=2_000.0, copy=True):
fname = default_fname if fname is None else fname
if key := (fname, nseg, max_branch_len) not in cells:
cells[key] = jx.read_swc(fname, nseg, max_branch_len, assign_groups=True)
Expand Down
77 changes: 27 additions & 50 deletions tests/jaxley_identical/test_basic_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_compartment(voltage_solver: str):
def test_compartment(voltage_solver, SimpleComp, SimpleBranch, SimpleCell, SimpleNet):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)
Expand All @@ -48,7 +48,7 @@ def test_compartment(voltage_solver: str):
)

# Test compartment.
comp = jx.Compartment()
comp = SimpleComp()
comp.insert(HH())
comp.record()
comp.stimulate(current)
Expand All @@ -57,7 +57,7 @@ def test_compartment(voltage_solver: str):
assert max_error <= tolerance, f"Compartment error is {max_error} > {tolerance}"

# Test branch of a single compartment.
branch = jx.Branch()
branch = SimpleBranch(nseg=1)
branch.insert(HH())
branch.record()
branch.stimulate(current)
Expand All @@ -66,7 +66,7 @@ def test_compartment(voltage_solver: str):
assert max_error <= tolerance, f"Branch error is {max_error} > {tolerance}"

# Test cell of a single compartment.
cell = jx.Cell()
cell = SimpleCell(1, 1)
cell.insert(HH())
cell.record()
cell.stimulate(current)
Expand All @@ -75,8 +75,7 @@ def test_compartment(voltage_solver: str):
assert max_error <= tolerance, f"Cell error is {max_error} > {tolerance}"

# Test net of a single compartment.
cell = jx.Cell()
net = jx.Network([cell])
net = SimpleNet(1, 1, 1)
net.insert(HH())
net.record()
net.stimulate(current)
Expand All @@ -86,14 +85,12 @@ def test_compartment(voltage_solver: str):


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_branch(voltage_solver: str):
nseg_per_branch = 2
def test_branch(voltage_solver, SimpleBranch):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
branch = SimpleBranch(2)
branch.insert(HH())
branch.loc(0.0).record()
branch.loc(0.0).stimulate(current)
Expand Down Expand Up @@ -122,13 +119,12 @@ def test_branch(voltage_solver: str):
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


def test_branch_fwd_euler_uneven_radiuses():
def test_branch_fwd_euler_uneven_radiuses(SimpleBranch):
dt = 0.025 # ms
t_max = 10.0 # ms
current = jx.step_current(0.5, 1.0, 2.0, dt, t_max)

comp = jx.Compartment()
branch = jx.Branch(comp, 8)
branch = SimpleBranch(8)
branch.set("axial_resistivity", 500.0)

rands1 = np.linspace(20, 300, 8)
Expand Down Expand Up @@ -161,18 +157,12 @@ def test_branch_fwd_euler_uneven_radiuses():


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_cell(voltage_solver: str):
nseg_per_branch = 2
def test_cell(voltage_solver, SimpleCell):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)

depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell = jx.Cell([branch for _ in range(len(parents))], parents=parents)
cell = SimpleCell(3, 2)
cell.insert(HH())
cell.branch(1).loc(0.0).record()
cell.branch(1).loc(0.0).stimulate(current)
Expand Down Expand Up @@ -201,17 +191,16 @@ def test_cell(voltage_solver: str):
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


def test_cell_unequal_compartment_number():
def test_cell_unequal_compartment_number(SimpleBranch):
"""Tests a cell where every branch has a different number of compartments."""
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.1, dt, t_max)

comp = jx.Compartment()
branch1 = jx.Branch(comp, nseg=1)
branch2 = jx.Branch(comp, nseg=2)
branch3 = jx.Branch(comp, nseg=3)
branch4 = jx.Branch(comp, nseg=4)
branch1 = SimpleBranch(nseg=1)
branch2 = SimpleBranch(nseg=2)
branch3 = SimpleBranch(nseg=3)
branch4 = SimpleBranch(nseg=4)
cell = jx.Cell([branch1, branch2, branch3, branch4], parents=[-1, 0, 0, 1])
cell.set("axial_resistivity", 10_000.0)
cell.insert(HH())
Expand All @@ -236,40 +225,32 @@ def test_cell_unequal_compartment_number():


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_net(voltage_solver: str):
nseg_per_branch = 2
def test_net(voltage_solver, SimpleNet):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)

depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell1 = jx.Cell([branch for _ in range(len(parents))], parents=parents)
cell2 = jx.Cell([branch for _ in range(len(parents))], parents=parents)
net = SimpleNet(2, 3, 2)

network = jx.Network([cell1, cell2])
connect(
network.cell(0).branch(0).loc(0.0),
network.cell(1).branch(0).loc(0.0),
net.cell(0).branch(0).loc(0.0),
net.cell(1).branch(0).loc(0.0),
IonotropicSynapse(),
)
network.insert(HH())
net.insert(HH())

for cell_ind in range(2):
network.cell(cell_ind).branch(1).loc(0.0).record()
net.cell(cell_ind).branch(1).loc(0.0).record()

for stim_ind in range(2):
network.cell(stim_ind).branch(1).loc(0.0).stimulate(current)
net.cell(stim_ind).branch(1).loc(0.0).stimulate(current)

area = 2 * pi * 10.0 * 1.0
point_process_to_dist_factor = 100_000.0 / area
network.IonotropicSynapse.set(
net.IonotropicSynapse.set(
"IonotropicSynapse_gS", 0.5 / point_process_to_dist_factor
)
voltages = jx.integrate(network, delta_t=dt, voltage_solver=voltage_solver)
voltages = jx.integrate(net, delta_t=dt, voltage_solver=voltage_solver)

voltages_300724 = jnp.asarray(
[
Expand Down Expand Up @@ -308,12 +289,8 @@ def test_net(voltage_solver: str):


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_complex_net(voltage_solver: str):
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])

net = jx.Network([cell for _ in range(7)])
def test_complex_net(voltage_solver, SimpleNet):
net = SimpleNet(7, 5, 4)
net.insert(HH())

_ = np.random.seed(0)
Expand Down
10 changes: 4 additions & 6 deletions tests/jaxley_identical/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import jax.numpy as jnp
import numpy as np
import pytest
from jax import value_and_grad

import jaxley as jx
Expand All @@ -22,12 +23,9 @@
from jaxley.synapses import IonotropicSynapse, TestSynapse


def test_network_grad():
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])

net = jx.Network([cell for _ in range(7)])
@pytest.mark.slow
def test_network_grad(SimpleNet):
net = SimpleNet(7, 5, 4)
net.insert(HH())

_ = np.random.seed(0)
Expand Down
62 changes: 23 additions & 39 deletions tests/jaxley_identical/test_radius_and_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_radius_and_length_compartment(voltage_solver: str):
def test_radius_and_length_compartment(voltage_solver, SimpleComp):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)

comp = jx.Compartment()
comp = SimpleComp()

np.random.seed(1)
comp.set("length", 5 * np.random.rand(1))
Expand Down Expand Up @@ -63,14 +63,12 @@ def test_radius_and_length_compartment(voltage_solver: str):


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_radius_and_length_branch(voltage_solver: str):
nseg_per_branch = 2
def test_radius_and_length_branch(voltage_solver, SimpleBranch):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
branch = SimpleBranch(nseg=2)

np.random.seed(1)
branch.set("length", np.flip(5 * np.random.rand(2)))
Expand Down Expand Up @@ -105,19 +103,13 @@ def test_radius_and_length_branch(voltage_solver: str):


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_radius_and_length_cell(voltage_solver: str):
nseg_per_branch = 2
def test_radius_and_length_cell(voltage_solver, SimpleCell):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)

depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
num_branches = len(parents)

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell = jx.Cell([branch for _ in range(len(parents))], parents=parents)
num_branches = 3
cell = SimpleCell(num_branches, nseg=2)

np.random.seed(1)
rands1 = 5 * np.random.rand(2 * num_branches)
Expand Down Expand Up @@ -155,57 +147,49 @@ def test_radius_and_length_cell(voltage_solver: str):


@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_radius_and_length_net(voltage_solver: str):
nseg_per_branch = 2
def test_radius_and_length_net(voltage_solver, SimpleNet):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)

depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
num_branches = len(parents)

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell1 = jx.Cell([branch for _ in range(len(parents))], parents=parents)
cell2 = jx.Cell([branch for _ in range(len(parents))], parents=parents)
num_branches = 3
net = SimpleNet(2, num_branches, 2)

np.random.seed(1)
rands1 = 5 * np.random.rand(2 * num_branches)
rands2 = np.random.rand(2 * num_branches)
for b in range(num_branches):
cell1.branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2]))
cell1.branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2]))
net.cell(0).branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2]))
net.cell(0).branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2]))

np.random.seed(2)
rands1 = 5 * np.random.rand(2 * num_branches)
rands2 = np.random.rand(2 * num_branches)
for b in range(num_branches):
cell2.branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2]))
cell2.branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2]))
net.cell(1).branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2]))
net.cell(1).branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2]))

network = jx.Network([cell1, cell2])
connect(
network.cell(0).branch(0).loc(0.0),
network.cell(1).branch(0).loc(0.0),
net.cell(0).branch(0).loc(0.0),
net.cell(1).branch(0).loc(0.0),
IonotropicSynapse(),
)
network.insert(HH())
net.insert(HH())

# first cell, 0-eth branch, 0-st compartment because loc=0.0
radius_post = network[1, 0, 0].nodes["radius"].item()
lenght_post = network[1, 0, 0].nodes["length"].item()
radius_post = net[1, 0, 0].nodes["radius"].item()
lenght_post = net[1, 0, 0].nodes["length"].item()
area = 2 * pi * lenght_post * radius_post
point_process_to_dist_factor = 100_000.0 / area
network.set("IonotropicSynapse_gS", 0.5 / point_process_to_dist_factor)
net.set("IonotropicSynapse_gS", 0.5 / point_process_to_dist_factor)

for cell_ind in range(2):
network.cell(cell_ind).branch(1).loc(0.0).record()
net.cell(cell_ind).branch(1).loc(0.0).record()

for stim_ind in range(2):
network.cell(stim_ind).branch(1).loc(0.0).stimulate(current)
net.cell(stim_ind).branch(1).loc(0.0).stimulate(current)

voltages = jx.integrate(network, delta_t=dt, voltage_solver=voltage_solver)
voltages = jx.integrate(net, delta_t=dt, voltage_solver=voltage_solver)

voltages_300724 = jnp.asarray(
[
Expand Down
Loading

0 comments on commit a5303ed

Please sign in to comment.