From c19448238eb550122469204b359fcc287e662300 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sun, 10 Nov 2024 22:26:14 +0100 Subject: [PATCH 01/15] wip: updating tests w fixtures v1 still wip --- tests/conftest.py | 86 ++++++++++++ tests/test_moving.py | 67 ++++----- tests/test_optimize.py | 8 +- tests/test_pickle.py | 5 + tests/test_plotting_api.py | 118 ++++++++-------- tests/test_record_and_stimulate.py | 29 +--- tests/test_set_ncomp.py | 77 +++++------ tests/test_solver.py | 11 +- tests/test_swc.py | 4 +- tests/test_syn.py | 7 +- tests/test_synapse_indexing.py | 30 ++-- tests/test_transforms.py | 6 +- tests/test_viewing.py | 213 +++++++++++++---------------- 13 files changed, 324 insertions(+), 337 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..03963314 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,86 @@ +import os +from copy import deepcopy + +import jax.numpy as jnp +import numpy as np +import pytest + +import jaxley as jx +from jaxley.channels import HH +from jaxley.synapses import IonotropicSynapse, TestSynapse + + +@pytest.fixture(scope="session") +def SimpleComp(): + comp = jx.Compartment() + + def get_comp(copy=False): + return deepcopy(comp) if copy else comp + + yield get_comp + comp = None + + +@pytest.fixture(scope="session") +def SimpleBranch(SimpleComp): + branches = {} + + def branch_w_shape(nseg, copy=False): + if nseg not in branches: + branches[nseg] = jx.Branch([SimpleComp()] * nseg) + return deepcopy(branches[nseg]) if copy else branches[nseg] + + yield branch_w_shape + branches = {} + + +@pytest.fixture(scope="session") +def SimpleCell(SimpleBranch): + cells = {} + + def cell_w_shape(nbranches, nseg_per_branch, copy=False): + if key := (nbranches, nseg_per_branch) not in cells: + parents = [-1] + depth = 0 + while nbranches > len(parents): + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + depth += 1 + parents = parents[:nbranches] + cells[key] = jx.Cell([SimpleBranch(nseg_per_branch)] * nbranches, parents) + return deepcopy(cells[key]) if copy else cells[key] + + yield cell_w_shape + cells = {} + + +@pytest.fixture(scope="session") +def SimpleNet(SimpleCell): + nets = {} + + def net_w_shape(n_cells, nbranches, nseg_per_branch, connect=False, copy=False): + if key := (n_cells, nbranches, nseg_per_branch, connect) not in nets: + net = jx.Network([SimpleCell(nbranches, nseg_per_branch)] * n_cells) + if connect: + jx.connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse()) + nets[key] = net + return deepcopy(nets[key]) if copy else nets[key] + + yield net_w_shape + nets = {} + + +@pytest.fixture(scope="session") +def SimpleMorphCell(): + dirname = os.path.dirname(__file__) + default_fname = os.path.join(dirname, "swc_files", "morph.swc") # n120 + + cells = {} + + def cell_w_params(fname=None, nseg=2, max_branch_len=2_000.0, copy=False): + fname = default_fname if fname is None else fname + if key := (fname, nseg, max_branch_len) not in cells: + cells[key] = jx.read_swc(fname, nseg, max_branch_len, assign_groups=True) + return deepcopy(cells[key]) if copy else cells[key] + + yield cell_w_params + cells = {} diff --git a/tests/test_moving.py b/tests/test_moving.py index 75e77135..d8bb3d9d 100644 --- a/tests/test_moving.py +++ b/tests/test_moving.py @@ -15,13 +15,9 @@ import jaxley as jx -def test_move_cell(): - nseg = 4 - +def test_move_cell(SimpleBranch, SimpleCell): # Test move on a cell with compute_xyz() - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) - cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) + cell = SimpleCell(5, nseg=4, copy=True) cell.compute_xyz() cell.move(20.0, 30.0, 5.0) assert cell.xyzr[0][0, 0] == 20.0 @@ -29,8 +25,7 @@ def test_move_cell(): assert cell.xyzr[0][0, 2] == 5.0 # Test move_to on a cell that starts with a specified xyzr - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) + branch = SimpleBranch(nseg=4, copy=True) cell = jx.Cell( branch, parents=[-1], @@ -50,11 +45,8 @@ def test_move_cell(): assert cell.xyzr[0][0, 3] == 10.0 -def test_move_network(): - nseg = 2 - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) - cell = jx.Cell([branch, branch, branch], parents=[-1, 0, 0]) +def test_move_network(SimpleCell): + cell = SimpleCell(3, 3, copy=True) cell.compute_xyz() net = jx.Network([cell, cell, cell]) net.move(20.0, 30.0, 5.0) @@ -64,19 +56,15 @@ def test_move_network(): assert net.xyzr[i][0, 2] == 5.0 -def test_move_to_cell(): - nseg = 4 - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) - cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) +def test_move_to_cell(SimpleBranch, SimpleCell): + cell = SimpleCell(5, 4, copy=True) cell.compute_xyz() cell.move_to(20.0, 30.0, 5.0) assert cell.xyzr[0][0, 0] == 20.0 assert cell.xyzr[0][0, 1] == 30.0 assert cell.xyzr[0][0, 2] == 5.0 - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) + branch = SimpleBranch(nseg=4) cell = jx.Cell( branch, parents=[-1], @@ -96,13 +84,9 @@ def test_move_to_cell(): assert cell.xyzr[0][0, 3] == 10.0 -def test_move_to_network(): - nseg = 4 - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) - cell = jx.Cell([branch, branch, branch], parents=[-1, 0, 0]) - cell.compute_xyz() - net = jx.Network([cell, cell, cell]) +def test_move_to_network(SimpleNet): + net = SimpleNet(3, 3, 4, copy=True) + net.compute_xyz() net.move_to(10.0, 20.0, 30.0) # Branch 0 of cell 0 assert net.xyzr[0][0, 0] == 10.0 @@ -114,14 +98,11 @@ def test_move_to_network(): assert net.xyzr[3][0, 2] == 30.0 -def test_move_to_arrays(): +def test_move_to_arrays(SimpleNet): """Test with network""" nseg = 4 - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) - cell = jx.Cell([branch, branch, branch], parents=[-1, 0, 0]) - cell.compute_xyz() - net = jx.Network([cell, cell, cell]) + net = SimpleNet(3, 3, nseg, copy=True) + net.compute_xyz() x_coords = np.array([10.0, 20.0, 30.0]) y_coords = np.array([5.0, 15.0, 25.0]) z_coords = np.array([1.0, 2.0, 3.0]) @@ -135,12 +116,9 @@ def test_move_to_arrays(): assert net.xyzr[6][0, 1] == 25.0 -def test_move_to_cellview(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=2) - cell = jx.Cell([branch, branch, branch], parents=[-1, 0, 0]) - cell.compute_xyz() - net = jx.Network([cell for _ in range(3)]) +def test_move_to_cellview(net): + net = net(3, 3, 2, copy=True) + net.compute_xyz() # Test with float input net.cell(0).move_to(50.0, 3.0, 40.0) @@ -149,7 +127,8 @@ def test_move_to_cellview(): assert net.xyzr[0][0, 2] == 40.0 # Test with array input - net = jx.Network([cell for _ in range(4)]) + net = net(4, 3, 2, copy=True) + net.compute_xyz() testx = np.array([1.0, 2.0, 3.0]) testy = np.array([4.0, 5.0, 6.0]) testz = np.array([7.0, 8.0, 9.0]) @@ -160,12 +139,12 @@ def test_move_to_cellview(): assert net.xyzr[9][0, 0] == 0.0 -def test_move_to_swc_cell(): +def test_move_to_swc_cell(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") - cell1 = jx.read_swc(fname, nseg=4) - cell2 = jx.read_swc(fname, nseg=4) - cell3 = jx.read_swc(fname, nseg=4) + cell1 = SimpleMorphCell(fname, nseg=4, copy=True) + cell2 = SimpleMorphCell(fname, nseg=4, copy=True) + cell3 = SimpleMorphCell(fname, nseg=4, copy=True) # Try move_to on a cell cell1.move_to(10.0, 20.0, 30.0) diff --git a/tests/test_optimize.py b/tests/test_optimize.py index 9da2b5df..52cc7ea6 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -17,9 +17,9 @@ from jaxley.optimize.utils import l2_norm -def test_type_optimizer_api(): +def test_type_optimizer_api(SimpleComp): """Tests whether optimization recovers a ground truth parameter set.""" - comp = jx.Compartment() + comp = SimpleComp(copy=True) comp.insert(HH()) comp.record() comp.stimulate(jx.step_current(0.1, 3.0, 0.1, 0.025, 5.0)) @@ -48,9 +48,9 @@ def loss_fn(params): opt_params = optax.apply_updates(opt_params, updates) -def test_type_optimizer(): +def test_type_optimizer(SimpleComp): """Tests whether optimization recovers a ground truth parameter set.""" - comp = jx.Compartment() + comp = SimpleComp(copy=True) comp.insert(HH()) comp.record() comp.stimulate(jx.step_current(0.1, 3.0, 0.1, 0.025, 5.0)) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 49b58f1e..a8652312 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -29,4 +29,9 @@ def test_pickle(module): pickled = pickle.dumps(module) unpickled = pickle.loads(pickled) + + view = module.select(0) + pickled = pickle.dumps(view) + unpickled = pickle.loads(pickled) + # assert module == unpickled # TODO: implement __eq__ for all classes diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index bef58daa..57e7a121 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -19,60 +19,30 @@ from jaxley.synapses import IonotropicSynapse -@pytest.fixture(scope="module") -def comp() -> jx.Compartment: - comp = jx.Compartment() - comp.compute_xyz() - return comp - - -@pytest.fixture(scope="module") -def branch(comp) -> jx.Branch: - branch = jx.Branch(comp, 4) - branch.compute_xyz() - return branch - - -@pytest.fixture(scope="module") -def cell(branch) -> jx.Cell: - cell = jx.Cell(branch, [-1, 0, 0, 1, 1]) - cell.branch(0).set_ncomp(3) - cell.compute_xyz() - return cell - - -@pytest.fixture(scope="module") -def simple_net(cell) -> jx.Network: - net = jx.Network([cell] * 4) - net.compute_xyz() - return net - - -@pytest.fixture(scope="module") -def morph_cell() -> jx.Cell: - morph_cell = jx.read_swc( - os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"), - nseg=1, - ) - morph_cell.branch(0).set_ncomp(2) - return morph_cell - +def test_cell(SimpleMorphCell): + dirname = os.path.dirname(__file__) + fname = os.path.join(dirname, "swc_files", "morph.swc") + cell = SimpleMorphCell(fname, nseg=4) -def test_cell(morph_cell): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - ax = morph_cell.vis(ax=ax) - ax = morph_cell.branch([0, 1, 2]).vis(ax=ax, col="r") - ax = morph_cell.branch(1).loc(0.9).vis(ax=ax, col="b") + ax = cell.vis(ax=ax) + ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r") + ax = cell.branch(1).loc(0.9).vis(ax=ax, col="b") # Plot 2. - morph_cell.branch(0).add_to_group("soma") - morph_cell.branch(1).add_to_group("soma") - ax = morph_cell.soma.vis() - - -def test_network(morph_cell): - net = jx.Network([morph_cell, morph_cell, morph_cell]) + cell.branch(0).add_to_group("soma") + cell.branch(1).add_to_group("soma") + ax = cell.soma.vis() + +def test_network(SimpleMorphCell): + dirname = os.path.dirname(__file__) + fname = os.path.join(dirname, "swc_files", "morph.swc") + cell1 = SimpleMorphCell(fname, nseg=4) + cell2 = SimpleMorphCell(fname, nseg=4) + cell3 = SimpleMorphCell(fname, nseg=4) + + net = jx.Network([cell1, cell2, cell3]) connect( net.cell(0).branch(0).loc(0.0), net.cell(1).branch(0).loc(0.0), @@ -110,7 +80,11 @@ def test_network(morph_cell): ax = net.excitatory.vis() -def test_vis_networks_built_from_scratch(comp, branch, cell): +def test_vis_networks_built_from_scartch(SimpleComp, SimpleBranch, SimpleCell): + comp = SimpleComp(copy=True) + branch = SimpleBranch(4) + cell = SimpleCell(5, 3) + net = jx.Network([cell, cell]) connect( net.cell(0).branch(0).loc(0.0), @@ -135,15 +109,25 @@ def test_vis_networks_built_from_scratch(comp, branch, cell): # Plot 3. _, ax = plt.subplots(1, 1, figsize=(3, 3)) + comp.compute_xyz() ax = comp.vis(ax=ax) # Plot 4. _, ax = plt.subplots(1, 1, figsize=(3, 3)) + branch.compute_xyz() ax = branch.vis(ax=ax) -def test_mixed_network(morph_cell, cell): - net = jx.Network([morph_cell, cell]) +def test_mixed_network(SimpleMorphCell): + dirname = os.path.dirname(__file__) + fname = os.path.join(dirname, "swc_files", "morph.swc") + cell1 = SimpleMorphCell(fname, nseg=4) + + comp = jx.Compartment() + branch = jx.Branch(comp, 4) + cell2 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) + + net = jx.Network([cell1, cell2]) connect( net.cell(0).branch(0).loc(0.0), net.cell(1).branch(0).loc(0.0), @@ -160,9 +144,9 @@ def test_mixed_network(morph_cell, cell): net.cell(1).move(0, -800) net.rotate(180) - before_xyzrs = deepcopy(net.xyzr[len(morph_cell.xyzr) :]) + before_xyzrs = deepcopy(net.xyzr[len(cell1.xyzr) :]) net.cell(1).rotate(90) - after_xyzrs = net.xyzr[len(morph_cell.xyzr) :] + after_xyzrs = net.xyzr[len(cell1.xyzr) :] # Test that rotation worked as expected. for b, a in zip(before_xyzrs, after_xyzrs): assert np.allclose(b[:, 0], -a[:, 1], atol=1e-6) @@ -171,24 +155,32 @@ def test_mixed_network(morph_cell, cell): _ = net.vis(detail="full") -def test_volume_plotting_2d(comp, branch, cell, simple_net, morph_cell): +def test_volume_plotting(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): + comp = SimpleComp() + branch = SimpleBranch(4) + cell = SimpleCell(3, 4) + net = SimpleNet(2,3,4) + for module in [comp, branch, cell, net]: + module.compute_xyz() + + morph_cell = jx.read_swc( + os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"), + nseg=1, + ) + fig, ax = plt.subplots() - for module in [comp, branch, cell, simple_net, morph_cell]: + for module in [comp, branch, cell, net, morph_cell]: module.vis(type="comp", ax=ax, morph_plot_kwargs={"resolution": 6}) plt.close(fig) - -def test_volume_plotting_3d(comp, branch, cell, simple_net, morph_cell): # test 3D plotting - for module in [comp, branch, cell, simple_net, morph_cell]: + for module in [comp, branch, cell, net, morph_cell]: module.vis(type="comp", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6}) plt.close() - -def test_morph_plotting(morph_cell): # test morph plotting (does not work if no radii in xyzr) - morph_cell.vis(type="morph", morph_plot_kwargs={"resolution": 6}) + morph_cell.vis(type="morph") morph_cell.branch(1).vis( type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6} ) # plotting whole thing takes too long - plt.close() + plt.close() \ No newline at end of file diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py index 47bd5bfd..5d476add 100644 --- a/tests/test_record_and_stimulate.py +++ b/tests/test_record_and_stimulate.py @@ -16,16 +16,9 @@ from jaxley.synapses import IonotropicSynapse, TestSynapse -def test_record_and_stimulate_api(): +def test_record_and_stimulate_api(SimpleCell): """Test the API for recording and stimulating.""" - nseg_per_branch = 2 - depth = 2 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - parents = jnp.asarray(parents) - - comp = jx.Compartment() - branch = jx.Branch(comp, nseg_per_branch) - cell = jx.Cell(branch, parents=parents) + cell = SimpleCell(3, 2, copy=True) cell.branch(0).loc(0.0).record() cell.branch(1).loc(1.0).record() @@ -37,16 +30,9 @@ def test_record_and_stimulate_api(): cell.delete_stimuli() -def test_record_shape(): +def test_record_shape(SimpleCell): """Test the API for recording and stimulating.""" - nseg_per_branch = 2 - depth = 2 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - parents = jnp.asarray(parents) - - comp = jx.Compartment() - branch = jx.Branch(comp, nseg_per_branch) - cell = jx.Cell(branch, parents=parents) + cell = SimpleCell(3, 2, copy=True) current = jx.step_current(0.0, 1.0, 1.0, 0.025, 3.0) cell.branch(1).loc(1.0).stimulate(current) @@ -64,7 +50,7 @@ def test_record_shape(): ), f"Shape of recordings ({voltages.shape}) is not right." -def test_record_synaptic_and_membrane_states(): +def test_record_synaptic_and_membrane_states(SimpleNet): """Tests recording of synaptic and membrane states. Tests are functional, not just API. They test whether the voltage and synaptic @@ -73,10 +59,7 @@ def test_record_synaptic_and_membrane_states(): _ = np.random.seed(0) # Seed because connectivity is at random postsyn locs. - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, parents=[-1]) - net = jx.Network([cell for _ in range(3)]) + net = SimpleNet(3, 1, 4, copy=True) net.insert(HH()) fully_connect(net.cell([0]), net.cell([1]), IonotropicSynapse()) diff --git a/tests/test_set_ncomp.py b/tests/test_set_ncomp.py index 81bff586..f154ef62 100644 --- a/tests/test_set_ncomp.py +++ b/tests/test_set_ncomp.py @@ -19,30 +19,27 @@ @pytest.mark.parametrize( "property", ["radius", "capacitance", "length", "axial_resistivity"] ) -def test_raise_for_heterogenous_modules(property): - comp = jx.Compartment() - branch0 = jx.Branch(comp, nseg=4) - branch1 = jx.Branch(comp, nseg=4) +def test_raise_for_heterogenous_modules(property, SimpleBranch): + branch0 = SimpleBranch(4, copy=True) + branch1 = SimpleBranch(4, copy=True) branch1.comp(1).set(property, 1.5) cell = jx.Cell([branch0, branch1], parents=[-1, 0]) with pytest.raises(ValueError): cell.branch(1).set_ncomp(2) -def test_raise_for_heterogenous_channel_existance(): - comp = jx.Compartment() - branch0 = jx.Branch(comp, nseg=4) - branch1 = jx.Branch(comp, nseg=4) +def test_raise_for_heterogenous_channel_existance(SimpleBranch): + branch0 = SimpleBranch(4, copy=True) + branch1 = SimpleBranch(4, copy=True) branch1.comp(2).insert(HH()) cell = jx.Cell([branch0, branch1], parents=[-1, 0]) with pytest.raises(ValueError): cell.branch(1).set_ncomp(2) -def test_raise_for_heterogenous_channel_properties(): - comp = jx.Compartment() - branch0 = jx.Branch(comp, nseg=4) - branch1 = jx.Branch(comp, nseg=4) +def test_raise_for_heterogenous_channel_properties(SimpleBranch): + branch0 = SimpleBranch(4, copy=True) + branch1 = SimpleBranch(4, copy=True) branch1.insert(HH()) branch1.comp(3).set("HH_gNa", 0.5) cell = jx.Cell([branch0, branch1], parents=[-1, 0]) @@ -50,54 +47,47 @@ def test_raise_for_heterogenous_channel_properties(): cell.branch(1).set_ncomp(2) -def test_raise_for_entire_cells(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1, 0, 0]) +def test_raise_for_entire_cells(SimpleCell): + cell = SimpleCell(3, 4) with pytest.raises(AssertionError): cell.set_ncomp(2) -def test_raise_for_networks(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell1 = jx.Cell(branch, parents=[-1, 0, 0]) - cell2 = jx.Cell(branch, parents=[-1, 0, 0]) +def test_raise_for_networks(SimpleCell): + cell1 = SimpleCell(3, 4) + cell2 = SimpleCell(3, 4) net = jx.Network([cell1, cell2]) with pytest.raises(AssertionError): net.cell(0).branch(1).set_ncomp(2) -def test_raise_for_recording(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1, 0]) +def test_raise_for_recording(SimpleCell): + cell = SimpleCell(3, 2, copy=True) cell.branch(0).comp(0).record() with pytest.raises(AssertionError): cell.branch(1).set_ncomp(2) -def test_raise_for_stimulus(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1, 0]) +def test_raise_for_stimulus(SimpleCell): + cell = SimpleCell(3, 2, copy=True) cell.branch(0).comp(0).stimulate(0.4 * jnp.ones(100)) with pytest.raises(AssertionError): cell.branch(1).set_ncomp(2) @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) -def test_simulation_accuracy_api_equivalence_init_vs_setncomp_branch(new_ncomp): +def test_simulation_accuracy_api_equivalence_init_vs_setncomp_branch( + new_ncomp, SimpleBranch +): """Test whether a module built from scratch matches module built with `set_ncomp()`. This makes one branch, whose `ncomp` is not modified, heterogenous. """ - comp = jx.Compartment() - branch1 = jx.Branch(comp, nseg=new_ncomp) + branch1 = SimpleBranch(new_ncomp, copy=True) # The second branch is originally instantiated to have 4 ncomp, but is later # modified to have `new_ncomp` compartments. - branch2 = jx.Branch(comp, nseg=4) + branch2 = SimpleBranch(4, copy=True) branch2.comp("all").set("length", 10.0) total_branch_len = 4 * 10.0 @@ -118,14 +108,15 @@ def test_simulation_accuracy_api_equivalence_init_vs_setncomp_branch(new_ncomp): @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) -def test_simulation_accuracy_api_equivalence_init_vs_setncomp_cell(new_ncomp): +def test_simulation_accuracy_api_equivalence_init_vs_setncomp_cell( + new_ncomp, SimpleBranch +): """Test whether a module built from scratch matches module built with `set_ncomp()`.""" - comp = jx.Compartment() - branch1 = jx.Branch(comp, nseg=new_ncomp) + branch1 = SimpleBranch(new_ncomp, copy=True) # The second branch is originally instantiated to have 4 ncomp, but is later # modified to have `new_ncomp` compartments. - branch2 = jx.Branch(comp, nseg=4) + branch2 = SimpleBranch(4, copy=True) branch2.comp("all").set("length", 10.0) total_branch_len = 4 * 10.0 @@ -150,13 +141,13 @@ def test_simulation_accuracy_api_equivalence_init_vs_setncomp_cell(new_ncomp): @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) @pytest.mark.parametrize("file", ["morph_250.swc"]) -def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file): +def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file, SimpleMorphCell): """Test if the radiuses and lenghts of an SWC morph are reconstructed correctly.""" dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) - cell1 = jx.read_swc(fname, nseg=new_ncomp, max_branch_len=2000.0) - cell2 = jx.read_swc(fname, nseg=4, max_branch_len=2000.0) + cell1 = SimpleMorphCell(fname, nseg=new_ncomp, copy=True) + cell2 = SimpleMorphCell(fname, nseg=4, copy=True) for b in range(cell2.total_nbranches): cell2.branch(b).set_ncomp(new_ncomp) @@ -171,13 +162,13 @@ def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file): @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) @pytest.mark.parametrize("file", ["morph_250.swc"]) -def test_simulation_accuracy_swc_init_vs_set_ncomp(new_ncomp, file): +def test_simulation_accuracy_swc_init_vs_set_ncomp(new_ncomp, file, SimpleMorphCell): """Test whether an SWC initially built with 4 ncomp works after `set_ncomp()`.""" dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) - cell1 = jx.read_swc(fname, nseg=new_ncomp, max_branch_len=2000.0) - cell2 = jx.read_swc(fname, nseg=4, max_branch_len=2000.0) + cell1 = SimpleMorphCell(fname, nseg=new_ncomp, copy=True) + cell2 = SimpleMorphCell(fname, nseg=4, copy=True) for b in range(cell2.total_nbranches): cell2.branch(b).set_ncomp(new_ncomp) diff --git a/tests/test_solver.py b/tests/test_solver.py index 251b96b7..be42f8d5 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -23,24 +23,17 @@ def test_exp_euler(x_inf): assert np.abs(fwd_euler - exp_euler) / np.abs(fwd_euler) < 1e-4 -def test_fwd_euler_and_crank_nicolson(): +def test_fwd_euler_and_crank_nicolson(SimpleNet): """FWD Euler does not yet support branched cells, but comps, branches, nets work. Tests whether forward Euler and Crank-Nicolson are sufficiently close to implicit Euler.""" - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - branch.insert(HH()) - cell = jx.Cell(branch, parents=[-1]) - net = jx.Network([cell for _ in range(2)]) + net = SimpleNet(2, 1, 4, connect=True, copy=True) current = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0) net.cell(0).branch(0).comp(0).stimulate(current) net.cell(1).branch(0).comp(3).record() - pre = net.cell(0).branch(0).comp(0) - post = net.cell(1).branch(0).comp(0) - connect(pre, post, IonotropicSynapse()) net.IonotropicSynapse.set("IonotropicSynapse_gS", 0.001) # As expected, using significantly shorter compartments or lower r_a leads to NaN diff --git a/tests/test_swc.py b/tests/test_swc.py index 09c2b44c..52757b99 100644 --- a/tests/test_swc.py +++ b/tests/test_swc.py @@ -105,7 +105,7 @@ def test_swc_radius(file): @pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"]) -def test_swc_voltages(file): +def test_swc_voltages(file, SimpleMorphCell): """Check if voltages of SWC recording match. To match the branch indices between NEURON and jaxley, we rely on comparing the @@ -143,7 +143,7 @@ def test_swc_voltages(file): ####################### jaxley ################## _, pathlengths, _, _, _ = jx.utils.swc.swc_to_jaxley(fname, max_branch_len=2_000) - cell = jx.read_swc(fname, nseg_per_branch, max_branch_len=2_000.0) + cell = SimpleMorphCell(fname, nseg_per_branch, max_branch_len=2_000.0, copy=True) cell.insert(HH()) trunk_inds = [1, 4, 5, 13, 15, 21, 23, 24, 29, 33] diff --git a/tests/test_syn.py b/tests/test_syn.py index 89d5ab6f..656bfa42 100644 --- a/tests/test_syn.py +++ b/tests/test_syn.py @@ -16,12 +16,9 @@ from jaxley.synapses import IonotropicSynapse, Synapse, TestSynapse -def test_set_and_querying_params_one_type(): +def test_set_and_querying_params_one_type(SimpleNet): """Test if the correct parameters are set if one type of synapses is inserted.""" - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1]) - net = jx.Network([cell for _ in range(4)]) + net = SimpleNet(4, 1, 4, copy=True) for pre_ind in [0, 1]: for post_ind in [2, 3]: diff --git a/tests/test_synapse_indexing.py b/tests/test_synapse_indexing.py index 136a38d7..69b4ece8 100644 --- a/tests/test_synapse_indexing.py +++ b/tests/test_synapse_indexing.py @@ -17,16 +17,13 @@ from jaxley.synapses import IonotropicSynapse, Synapse, TanhRateSynapse, TestSynapse -def test_multiparameter_setting(): +def test_multiparameter_setting(SimpleNet): """ Test if the correct parameters are set if one type of synapses is inserted. Tests global index dropping: d4daaf019596589b9430219a15f1dda0b1c34d85 """ - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1]) - net = jx.Network([cell for _ in range(2)]) + net = SimpleNet(2, 1, 4, copy=True) pre = net.cell(0).branch(0).loc(0.0) post = net.cell(1).branch(0).loc(0.0) @@ -59,13 +56,10 @@ def _get_synapse_view(net, synapse_name, single_idx=1, double_idxs=[2, 3]): @pytest.mark.parametrize( "synapse_type", [IonotropicSynapse, TanhRateSynapse, TestSynapse] ) -def test_set_and_querying_params_one_type(synapse_type): +def test_set_and_querying_params_one_type(synapse_type, SimpleNet): """Test if the correct parameters are set if one type of synapses is inserted.""" synapse_type = synapse_type() - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1]) - net = jx.Network([cell for _ in range(4)]) + net = SimpleNet(4, 1, 4, copy=True) for pre_ind in [0, 1]: for post_ind in [2, 3]: @@ -100,13 +94,10 @@ def test_set_and_querying_params_one_type(synapse_type): @pytest.mark.parametrize("synapse_type", [TanhRateSynapse, TestSynapse]) -def test_set_and_querying_params_two_types(synapse_type): +def test_set_and_querying_params_two_types(synapse_type, SimpleNet): """Test whether the correct parameters are set.""" synapse_type = synapse_type() - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1]) - net = jx.Network([cell for _ in range(4)]) + net = SimpleNet(4, 1, 4, copy=True) for pre_ind in [0, 1]: for post_ind, synapse in zip([2, 3], [IonotropicSynapse(), synapse_type]): @@ -159,15 +150,12 @@ def test_set_and_querying_params_two_types(synapse_type): @pytest.mark.parametrize("synapse_type", [TanhRateSynapse, TestSynapse]) -def test_shuffling_order_of_set(synapse_type): +def test_shuffling_order_of_set(synapse_type, SimpleNet): """Test whether the result is the same if the order of synapses is changed.""" synapse_type = synapse_type() - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1]) - net1 = jx.Network([cell for _ in range(4)]) - net2 = jx.Network([cell for _ in range(4)]) + net1 = SimpleNet(4, 1, 4, copy=True) + net2 = SimpleNet(4, 1, 4, copy=True) connect( net1.cell(0).branch(0).loc(1.0), diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 323441fc..a5672a68 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -151,10 +151,8 @@ def test_correct(transform): "transform", [jt.SigmoidTransform(-2, 2), jt.SoftplusTransform(2), jt.NegSoftplusTransform(2)], ) -def test_user_api(transform): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=2) - cell = jx.Cell(branch, parents=[-1, 0, 0]) +def test_user_api(transform, SimpleCell): + cell = SimpleCell(3, 2, copy=True) cell.branch("all").make_trainable("radius") cell.branch(2).make_trainable("radius") diff --git a/tests/test_viewing.py b/tests/test_viewing.py index ca2983d6..cb75130b 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -23,11 +23,10 @@ from jaxley.utils.solver_utils import JaxleySolveIndexer -def test_getitem(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(4)]) - cell = jx.Cell([branch for _ in range(3)], parents=jnp.asarray([-1, 0, 0])) - net = jx.Network([cell for _ in range(3)]) +def test_getitem(SimpleBranch, SimpleCell, SimpleNet): + branch = SimpleBranch(4) + cell = SimpleCell(3, 4) + net = SimpleNet(3, 3, 4) # test API equivalence assert all(net.cell(0).branch(0).show() == net[0, 0].show()) @@ -57,11 +56,8 @@ def test_getitem(): pass -def test_loc_v_comp(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(4)]) - - cum_nseg = branch.cumsum_nseg +def test_loc_v_comp(SimpleBranch): + branch = SimpleBranch(4) nsegs = branch.nseg_per_branch branch_ind = 0 @@ -75,11 +71,11 @@ def test_loc_v_comp(): assert np.all(branch.comp(inferred_ind).show() == branch.loc(0.4).show()) -def test_shape(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(4)]) - cell = jx.Cell([branch for _ in range(3)], parents=jnp.asarray([-1, 0, 0])) - net = jx.Network([cell for _ in range(3)]) +def test_shape(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): + comp = SimpleComp() + branch = SimpleBranch(4) + cell = SimpleCell(3, 4) + net = SimpleNet(3, 3, 4) assert net.shape == (3, 3 * 3, 3 * 3 * 4) assert cell.shape == (3, 3 * 4) @@ -98,11 +94,10 @@ def test_shape(): assert net.cell(0).branch(0).comp(0).shape == (1, 1, 1) -def test_set_and_insert(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(4)]) - cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1])) - net = jx.Network([cell for _ in range(5)]) +def test_set_and_insert(SimpleBranch, SimpleCell, SimpleNet): + branch = SimpleBranch(4) + cell = SimpleCell(5, 4) + net = SimpleNet(5, 5, 4) net1 = deepcopy(net) net2 = deepcopy(net) net3 = deepcopy(net) @@ -186,11 +181,8 @@ def test_set_and_insert(): assert np.all(cell1.recordings == cell2.recordings) -def test_local_indexing(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(4)]) - cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1])) - net = jx.Network([cell for _ in range(2)]) +def test_local_indexing(SimpleNet): + net = SimpleNet(2, 5, 4) local_idxs = net.nodes[ ["local_cell_index", "local_branch_index", "local_comp_index"] @@ -206,11 +198,10 @@ def test_local_indexing(): global_index += 1 -def test_indexing_a_compartment_of_many_branches(): - comp = jx.Compartment() - branch1 = jx.Branch(comp, nseg=3) - branch2 = jx.Branch(comp, nseg=4) - branch3 = jx.Branch(comp, nseg=5) +def test_indexing_a_compartment_of_many_branches(SimpleBranch): + branch1 = SimpleBranch(nseg=3) + branch2 = SimpleBranch(nseg=4) + branch3 = SimpleBranch(nseg=5) cell1 = jx.Cell([branch1, branch2, branch3], parents=[-1, 0, 0]) cell2 = jx.Cell([branch3, branch2], parents=[-1, 0]) net = jx.Network([cell1, cell2]) @@ -247,16 +238,8 @@ def test_solve_indexer(): assert np.all(idx.upper(branch_inds) == np.asarray([[0, 1, 2], [7, 8, 9]])) -comp = jx.Compartment() -branch = jx.Branch(comp, nseg=3) -cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) -net = jx.Network([cell] * 3) -connect(net[0, 0, 0], net[0, 0, 1], TestSynapse()) - - # make sure all attrs in module also have a corresponding attr in view -@pytest.mark.parametrize("module", [comp, branch, cell, net]) -def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): +def test_view_attrs(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): """Check if all attributes of Module have a corresponding attribute in View. To ensure that View behaves like a Module as much as possible, View should support @@ -289,26 +272,24 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): "_cumsum_nseg_per_cell", ] # for network - for name, attr in module.__dict__.items(): - if name not in exceptions: - # check if attr is in view - view = View(module) - assert hasattr(view, name), f"View missing attribute: {name}" - # check if types match - assert type(getattr(module, name)) == type( - getattr(view, name) - ), f"Type mismatch: {name}, Module type: {type(getattr(module, name))}, View type: {type(getattr(view, name))}" - - -comp = jx.Compartment() -branch = jx.Branch(nseg=4) -cell = jx.Cell([branch] * 5, parents=[-1, 0, 0, 1, 1]) -net = jx.Network([cell] * 2) -connect(net[0, 0, :], net[1, 0, :], TestSynapse()) - - -@pytest.mark.parametrize("module", [comp, branch, cell, net]) -def test_view_supported_index_types(module): + for module in [ + SimpleComp(), + SimpleBranch(2), + SimpleCell(2, 3), + SimpleNet(2, 2, 3, connect=True), + ]: + for name, attr in module.__dict__.items(): + if name not in exceptions: + # check if attr is in view + view = View(module) + assert hasattr(view, name), f"View missing attribute: {name}" + # check if types match + assert type(getattr(module, name)) == type( + getattr(view, name) + ), f"Type mismatch: {name}, Module type: {type(getattr(module, name))}, View type: {type(getattr(view, name))}" + + +def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): """Check if different ways to index into Modules/Views work correctly.""" # test int, range, slice, list, np.array, pd.Index index_types = [ @@ -321,42 +302,46 @@ def test_view_supported_index_types(module): np.array([True, False, True, False] * 100)[: len(module.nodes)], ] - # comp.comp is not allowed - all_inds = module.nodes.index.to_numpy() - if not isinstance(module, jx.Compartment): - # `_reformat_index` should always return a np.ndarray - for index in index_types: - assert isinstance( - module._reformat_index(index), np.ndarray - ), f"Failed for {type(index)}" - - # test indexing into module and view - assert module.comp(index), f"Failed for {type(index)}" - assert View(module).comp(index), f"Failed for {type(index)}" - - expected_inds = all_inds[index] - assert np.all(module.select(nodes=index).nodes.index == expected_inds) - - # for loc test float and list of floats - assert module.loc(0.0), "Failed for float" - assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" - else: - with pytest.raises(AssertionError): - module.comp(0) - - if isinstance(module, jx.Network): - all_inds = module.edges.index.to_numpy() - for index in index_types[:-1] + [np.array([True, False, True, False])]: - expected_inds = all_inds[index] - assert np.all(net.select(edges=index).edges.index == expected_inds) - - -def test_select(): + for module in [ + SimpleComp(), + SimpleBranch(4), + SimpleCell(3, 4), + SimpleNet(2, 3, 4, connect=True), + ]: + + # comp.comp is not allowed + all_inds = module.nodes.index.to_numpy() + if not isinstance(module, jx.Compartment): + # `_reformat_index` should always return a np.ndarray + for index in index_types: + assert isinstance( + module._reformat_index(index), np.ndarray + ), f"Failed for {type(index)}" + + # test indexing into module and view + assert module.comp(index), f"Failed for {type(index)}" + assert View(module).comp(index), f"Failed for {type(index)}" + + expected_inds = all_inds[index] + assert np.all(module.select(nodes=index).nodes.index == expected_inds) + + # for loc test float and list of floats + assert module.loc(0.0), "Failed for float" + assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" + else: + with pytest.raises(AssertionError): + module.comp(0) + + if isinstance(module, jx.Network): + all_inds = module.edges.index.to_numpy() + for index in index_types[:-1] + [np.array([True, False, True, False])]: + expected_inds = all_inds[index] + assert np.all(module.select(edges=index).edges.index == expected_inds) + + +def test_select(SimpleNet): """Ensure `select` works correctly and returns expected View of Modules.""" - comp = jx.Compartment() - branch = jx.Branch([comp] * 3) - cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) - net = jx.Network([cell] * 3) + net = SimpleNet(3, 3, 2, connect=False, copy=True) connect(net[0, 0, :], net[1, 0, :], TestSynapse()) np.random.seed(0) @@ -393,12 +378,10 @@ def test_select(): ), "Selecting nodes and edges by index failed for edges." -def test_viewing(): +def test_viewing(SimpleCell, SimpleNet): """Test that the View object is working correctly.""" - comp = jx.Compartment() - branch = jx.Branch([comp] * 3) - cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) - net = jx.Network([cell] * 3) + cell = SimpleCell(3, 3) + net = SimpleNet(3, 3, 3) # test parameter sharing works correctly nodes1 = net.branch(0).comp("all").nodes @@ -447,11 +430,9 @@ def test_viewing(): net.scope("global").comp(999) # Nothing should be in View -def test_scope(): +def test_scope(SimpleCell): """Ensure scope has the intended effect for Modules and Views.""" - comp = jx.Compartment() - branch = jx.Branch([comp] * 3) - cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + cell = SimpleCell(3, 3) view = cell.scope("global").branch(1) assert view._scope == "global" @@ -481,11 +462,9 @@ def test_scope(): ) -def test_context_manager(): +def test_context_manager(SimpleCell): """Test that context manager works correctly for Module.""" - comp = jx.Compartment() - branch = jx.Branch([comp] * 3) - cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + cell = SimpleCell(3, 3) with cell.branch(0).comp(0) as comp: comp.set("v", -71) @@ -506,11 +485,10 @@ def test_context_manager(): ), "Context management of View not working." -def test_iter(): +def test_iter(SimpleBranch): """Test that __iter__ works correctly for all modules.""" - comp = jx.Compartment() - branch1 = jx.Branch([comp] * 2) - branch2 = jx.Branch([comp] * 3) + branch1 = SimpleBranch(2) + branch2 = SimpleBranch(3) cell = jx.Cell([branch1, branch1, branch2], parents=[-1, 0, 0]) net = jx.Network([cell] * 2) @@ -566,12 +544,9 @@ def test_iter(): assert np.all(cell.nodes["v"] == -73), "Setting parameters with __iter__ failed." -def test_synapse_and_channel_filtering(): +def test_synapse_and_channel_filtering(SimpleNet): """Test that synapses and channels are filtered correctly by View.""" - comp = jx.Compartment() - branch = jx.Branch([comp] * 3) - cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) - net = jx.Network([cell] * 3) + net = SimpleNet(3, 3, 3, connect=False, copy=True) net.insert(HH()) connect(net[0, 0, :], net[1, 0, :], TestSynapse()) @@ -595,10 +570,10 @@ def test_synapse_and_channel_filtering(): assert np.all(edges1 == edges2) -def test_view_equals_module(): +def test_view_equals_module(SimpleComp, SimpleBranch): """Test that View behaves the same as Module for important attrs and methods.""" - comp = jx.Compartment() - branch = jx.Branch([comp] * 3) + comp = SimpleComp(copy=True) + branch = SimpleBranch(3, copy=True) comp.insert(HH()) branch.comp([0, 1]).insert(HH()) From 3adef2bf91ef64bceae8a876b3b61ce54afd16a7 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 14 Nov 2024 11:28:39 +0100 Subject: [PATCH 02/15] wip: save wip refactoring with Module test fixtures --- tests/conftest.py | 40 +++++++--- tests/test_api_equivalence.py | 68 +++++++---------- tests/test_cell_matches_branch.py | 16 +--- tests/test_channels.py | 18 ++--- tests/test_clamp.py | 47 +++++------- tests/test_connection.py | 29 +++---- tests/test_data_feeding.py | 12 +-- tests/test_distance.py | 10 +-- tests/test_grad.py | 9 +-- tests/test_groups.py | 40 ++++------ tests/test_make_trainable.py | 117 ++++++++++------------------- tests/test_moving.py | 22 +++--- tests/test_record_and_stimulate.py | 10 +-- tests/test_set_ncomp.py | 32 ++++---- tests/test_solver.py | 2 +- tests/test_swc.py | 22 +++--- tests/test_syn.py | 2 +- tests/test_synapse_indexing.py | 10 +-- tests/test_transforms.py | 2 +- tests/test_viewing.py | 6 +- 20 files changed, 213 insertions(+), 301 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 03963314..8cf46ac5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + import os from copy import deepcopy @@ -14,18 +17,18 @@ def SimpleComp(): comp = jx.Compartment() - def get_comp(copy=False): + def get_comp(copy=True): return deepcopy(comp) if copy else comp yield get_comp - comp = None + del comp @pytest.fixture(scope="session") def SimpleBranch(SimpleComp): branches = {} - def branch_w_shape(nseg, copy=False): + def branch_w_shape(nseg, copy=True): if nseg not in branches: branches[nseg] = jx.Branch([SimpleComp()] * nseg) return deepcopy(branches[nseg]) if copy else branches[nseg] @@ -38,15 +41,15 @@ def branch_w_shape(nseg, copy=False): def SimpleCell(SimpleBranch): cells = {} - def cell_w_shape(nbranches, nseg_per_branch, copy=False): - if key := (nbranches, nseg_per_branch) not in cells: + def cell_w_shape(nbranches, nseg, copy=True): + if key := (nbranches, nseg) not in cells: parents = [-1] depth = 0 while nbranches > len(parents): parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] depth += 1 parents = parents[:nbranches] - cells[key] = jx.Cell([SimpleBranch(nseg_per_branch)] * nbranches, parents) + cells[key] = jx.Cell([SimpleBranch(nseg)] * nbranches, parents) return deepcopy(cells[key]) if copy else cells[key] yield cell_w_shape @@ -57,9 +60,9 @@ def cell_w_shape(nbranches, nseg_per_branch, copy=False): def SimpleNet(SimpleCell): nets = {} - def net_w_shape(n_cells, nbranches, nseg_per_branch, connect=False, copy=False): - if key := (n_cells, nbranches, nseg_per_branch, connect) not in nets: - net = jx.Network([SimpleCell(nbranches, nseg_per_branch)] * n_cells) + def net_w_shape(ncells, nbranches, nseg, connect=False, copy=True): + if key := (ncells, nbranches, nseg, connect) not in nets: + net = jx.Network([SimpleCell(nbranches, nseg)] * ncells) if connect: jx.connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse()) nets[key] = net @@ -76,7 +79,7 @@ def SimpleMorphCell(): cells = {} - def cell_w_params(fname=None, nseg=2, max_branch_len=2_000.0, copy=False): + def cell_w_params(fname=None, nseg=2, max_branch_len=2_000.0, copy=True): fname = default_fname if fname is None else fname if key := (fname, nseg, max_branch_len) not in cells: cells[key] = jx.read_swc(fname, nseg, max_branch_len, assign_groups=True) @@ -84,3 +87,20 @@ def cell_w_params(fname=None, nseg=2, max_branch_len=2_000.0, copy=False): yield cell_w_params cells = {} + + +@pytest.fixture(scope="session") +def swc2jaxley(): + dirname = os.path.dirname(__file__) + default_fname = os.path.join(dirname, "swc_files", "morph.swc") # n120 + + params = {} + + def swc2jaxley_params(fname=None, max_branch_len=2_000.0, sort=True): + fname = default_fname if fname is None else fname + if key := (fname, max_branch_len, sort) not in params: + params[key] = jx.utils.swc.swc_to_jaxley(fname, max_branch_len, sort) + return params[key] + + yield swc2jaxley_params + params = {} diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py index 85207b93..fa11f839 100644 --- a/tests/test_api_equivalence.py +++ b/tests/test_api_equivalence.py @@ -48,9 +48,9 @@ def test_api_equivalence_morphology(): ), "Voltages do not match between morphology APIs." -def test_solver_backends_comp(): +def test_solver_backends_comp(SimpleComp): """Test whether ways of adding synapses are equivalent.""" - comp = jx.Compartment() + comp = SimpleComp() current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) comp.stimulate(current) @@ -64,10 +64,9 @@ def test_solver_backends_comp(): assert max_error < 1e-8, f"{message} thomas/stone. Error={max_error}" -def test_solver_backends_branch(): +def test_solver_backends_branch(SimpleBranch): """Test whether ways of adding synapses are equivalent.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) + branch = SimpleBranch(4) current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) branch.loc(0.0).stimulate(current) @@ -81,16 +80,14 @@ def test_solver_backends_branch(): assert max_error < 1e-8, f"{message} thomas/stone. Error={max_error}" -def test_solver_backends_cell(): +def test_solver_backends_cell(SimpleCell): """Test whether ways of adding synapses are equivalent.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) + cell = SimpleCell(4, 4) current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) cell.branch(0).loc(0.0).stimulate(current) cell.branch(0).loc(0.5).record() - cell.branch(4).loc(0.5).record() + cell.branch(3).loc(0.5).record() voltages_jx_thomas = jx.integrate(cell, voltage_solver="jaxley.thomas") voltages_jx_stone = jx.integrate(cell, voltage_solver="jaxley.stone") @@ -100,14 +97,10 @@ def test_solver_backends_cell(): assert max_error < 1e-8, f"{message} thomas/stone. Error={max_error}" -def test_solver_backends_net(): +def test_solver_backends_net(SimpleNet): """Test whether ways of adding synapses are equivalent.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell1 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) - cell2 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) + net = SimpleNet(2, 4, 4) - net = jx.Network([cell1, cell2]) connect( net.cell(0).branch(0).loc(1.0), net.cell(1).branch(4).loc(1.0), @@ -122,7 +115,7 @@ def test_solver_backends_net(): current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() - net.cell(1).branch(4).loc(0.5).record() + net.cell(1).branch(3).loc(0.5).record() voltages_jx_thomas = jx.integrate(net, voltage_solver="jaxley.thomas") voltages_jx_stone = jx.integrate(net, voltage_solver="jaxley.stone") @@ -132,39 +125,35 @@ def test_solver_backends_net(): assert max_error < 1e-8, f"{message} thomas/stone. Error={max_error}" -def test_api_equivalence_synapses(): +def test_api_equivalence_synapses(SimpleNet): """Test whether ways of adding synapses are equivalent.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell1 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) - cell2 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) + net1 = SimpleNet(2, 4, 4) - net1 = jx.Network([cell1, cell2]) connect( net1.cell(0).branch(0).loc(1.0), - net1.cell(1).branch(4).loc(1.0), + net1.cell(1).branch(3).loc(1.0), IonotropicSynapse(), ) connect( net1.cell(1).branch(1).loc(0.8), - net1.cell(0).branch(4).loc(0.1), + net1.cell(0).branch(3).loc(0.1), IonotropicSynapse(), ) - net2 = jx.Network([cell1, cell2]) + net2 = SimpleNet(2, 4, 4) pre = net2.cell(0).branch(0).loc(1.0) - post = net2.cell(1).branch(4).loc(1.0) + post = net2.cell(1).branch(3).loc(1.0) connect(pre, post, IonotropicSynapse()) pre = net2.cell(1).branch(1).loc(0.8) - post = net2.cell(0).branch(4).loc(0.1) + post = net2.cell(0).branch(3).loc(0.1) connect(pre, post, IonotropicSynapse()) for net in [net1, net2]: current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() - net.cell(1).branch(4).loc(0.5).record() + net.cell(1).branch(3).loc(0.5).record() voltages1 = jx.integrate(net1) voltages2 = jx.integrate(net2) @@ -174,10 +163,8 @@ def test_api_equivalence_synapses(): ), "Voltages do not match between synapse APIs." -def test_api_equivalence_continued_simulation(): - comp = jx.Compartment() - branch = jx.Branch(comp, 2) - cell = jx.Cell(branch, parents=[-1, 0, 0]) +def test_api_equivalence_continued_simulation(SimpleCell): + cell = SimpleCell(3, 2) cell.insert(HH()) cell[0, 1].record() @@ -189,7 +176,7 @@ def test_api_equivalence_continued_simulation(): assert np.max(np.abs(v1 - v2)) < 1e-8 -def test_api_equivalence_network_matches_cell(): +def test_api_equivalence_network_matches_cell(SimpleBranch): """Test whether a network with w=0 synapses equals the individual cells. This runs an unequal number of compartments per branch.""" @@ -197,10 +184,9 @@ def test_api_equivalence_network_matches_cell(): t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.1, dt, t_max) - comp = jx.Compartment() - branch1 = jx.Branch(comp, nseg=1) - branch2 = jx.Branch(comp, nseg=2) - branch3 = jx.Branch(comp, nseg=3) + branch1 = SimpleBranch(nseg=1) + branch2 = SimpleBranch(nseg=2) + branch3 = SimpleBranch(nseg=3) cell1 = jx.Cell([branch1, branch2, branch3], parents=[-1, 0, 0]) cell2 = jx.Cell([branch1, branch2], parents=[-1, 0]) cell1.insert(HH()) @@ -232,10 +218,8 @@ def test_api_equivalence_network_matches_cell(): assert max_error < 1e-8, f"Error is {max_error}" -def test_api_init_step_to_integrate(): - comp = jx.Compartment() - branch = jx.Branch(comp, 2) - cell = jx.Cell(branch, parents=[-1, 0, 0]) +def test_api_init_step_to_integrate(SimpleCell): + cell = SimpleCell(3, 2) cell.insert(HH()) cell[0, 1].record() diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index 4b43d4ad..a1a3d274 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -15,11 +15,8 @@ from jaxley.channels import HH -def _run_long_branch(dt, t_max): - nseg_per_branch = 8 - - comp = jx.Compartment() - branch = jx.Branch(comp, nseg_per_branch) +def _run_long_branch(dt, t_max, SimpleBranch): + branch = SimpleBranch(8) branch.insert(HH()) branch.loc("all").make_trainable("radius", 1.0) @@ -38,13 +35,8 @@ def loss(params): return l, g -def _run_short_branches(dt, t_max): - nseg_per_branch = 4 - parents = jnp.asarray([-1, 0]) - - comp = jx.Compartment() - branch = jx.Branch(comp, nseg_per_branch) - cell = jx.Cell(branch, parents=parents) +def _run_short_branches(dt, t_max, SimpleCell): + cell = SimpleCell(2, 4) cell.insert(HH()) cell.branch("all").loc("all").make_trainable("radius", 1.0) diff --git a/tests/test_channels.py b/tests/test_channels.py index 41024040..5870a10f 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -164,15 +164,13 @@ def test_integration_with_renamed_channels(): assert np.invert(np.any(np.isnan(v))) -def test_init_states(): +def test_init_states(SimpleCell): """Functional test for `init_states()`. Checks whether, if everything is initialized in its steady state, the voltage after 10ms is almost exactly the same as after 0ms. """ - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) + cell = SimpleCell(2, 4) cell.branch(0).loc(0.0).record() cell.branch(0).insert(Na()) @@ -257,7 +255,7 @@ def m_gate(v, cai, q10): return m_inf, tau_m -def test_init_states_complex_channel(): +def test_init_states_complex_channel(SimpleCell): """Test for `init_states()` with a more complicated channel model. The channel model used for this test uses the `states` in `init_state` and it also @@ -265,9 +263,7 @@ def test_init_states_complex_channel(): an issue I had with Jaxley in v0.2.0 (fixed in v0.2.1). """ ## Create cell - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=1) - cell = jx.Cell(branch, parents=[-1, 0, 0]) + cell = SimpleCell(3, 1) # CA channels. cell.branch([0, 1]).insert(CaNernstReversal()) @@ -283,7 +279,7 @@ def test_init_states_complex_channel(): assert np.invert(np.any(np.isnan(voltages))), "NaN voltage found" -def test_multiple_channel_currents(): +def test_multiple_channel_currents(SimpleCell): """Test whether all channels can""" class User(Channel): @@ -334,9 +330,7 @@ def compute_current(self, states, v, params): dt = 0.025 # ms t_max = 10.0 # ms - comp = jx.Compartment() - branch = jx.Branch(comp, 1) - cell = jx.Cell(branch, parents=[-1]) + cell = SimpleCell(1, 1) cell.branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, dt, t_max)) cell.insert(User()) diff --git a/tests/test_clamp.py b/tests/test_clamp.py index 8253cd5b..dd8f74f6 100644 --- a/tests/test_clamp.py +++ b/tests/test_clamp.py @@ -18,8 +18,8 @@ from jaxley.channels import HH, CaL, CaT, Channel, K, Km, Leak, Na -def test_clamp_pointneuron(): - comp = jx.Compartment() +def test_clamp_pointneuron(SimpleComp): + comp = SimpleComp() comp.insert(HH()) comp.record() comp.clamp("v", -50.0 * jnp.ones((1000,))) @@ -28,8 +28,8 @@ def test_clamp_pointneuron(): assert np.all(v[:, 1:] == -50.0) -def test_clamp_currents(): - comp = jx.Compartment() +def test_clamp_currents(SimpleComp): + comp = SimpleComp() comp.insert(HH()) comp.record("i_HH") @@ -49,12 +49,8 @@ def test_clamp_currents(): assert np.all(np.isclose(i1, i2)) -def test_clamp_synapse(): - comp = jx.Compartment() - branch = jx.Branch(comp, 1) - cell1 = jx.Cell(branch, [-1]) - cell2 = jx.Cell(branch, [-1]) - net = jx.Network([cell1, cell2]) +def test_clamp_synapse(SimpleNet): + net = SimpleNet(2, 1, 1) connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse()) net.record("IonotropicSynapse_s") @@ -76,9 +72,8 @@ def test_clamp_synapse(): assert np.all(np.isclose(s1, s2)) -def test_clamp_multicompartment(): - comp = jx.Compartment() - branch = jx.Branch(comp, 4) +def test_clamp_multicompartment(SimpleBranch): + branch = SimpleBranch(4) branch.insert(HH()) branch.record() branch.comp(0).clamp("v", -50.0 * jnp.ones((1000,))) @@ -92,12 +87,10 @@ def test_clamp_multicompartment(): assert np.all(np.std(v[1:, 1:], axis=1) > 0.1) -def test_clamp_and_stimulate_api(): +def test_clamp_and_stimulate_api(SimpleCell): """Ensure proper behaviour when `.clamp()` and `.stimulate()` are combined.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell1 = jx.Cell(branch, [-1]) - cell2 = jx.Cell(branch, [-1]) + cell1 = SimpleCell(1, 4) + cell2 = SimpleCell(1, 4) net = jx.Network([cell1, cell2]) net.insert(HH()) @@ -123,9 +116,9 @@ def test_clamp_and_stimulate_api(): assert np.max(np.abs(vs1 - vs2)) < 1e-8 -def test_data_clamp(): +def test_data_clamp(SimpleComp): """Data clamp with no stimuli or data_stimuli, and no t_max (should get defined by the clamp).""" - comp = jx.Compartment() + comp = SimpleComp() comp.insert(HH()) comp.record() clamp = -50.0 * jnp.ones((1000,)) @@ -144,9 +137,9 @@ def simulate(clamp): assert np.all(s[:, 1:] == -50.0) -def test_data_clamp_and_data_stimulate(): +def test_data_clamp_and_data_stimulate(SimpleComp): """In theory people shouldn't use these two together, but at least it shouldn't break.""" - comp = jx.Compartment() + comp = SimpleComp() comp.insert(HH()) comp.record() clamp = -50.0 * jnp.ones((1000,)) @@ -167,9 +160,9 @@ def simulate(clamp, stim): assert np.all(s[:, 1:] == -50.0) -def test_data_clamp_and_stimulate(): +def test_data_clamp_and_stimulate(SimpleComp): """Test that data clamp overrides a previously set stimulus.""" - comp = jx.Compartment() + comp = SimpleComp() comp.insert(HH()) comp.record() clamp = -50.0 * jnp.ones((1000,)) @@ -187,9 +180,9 @@ def simulate(clamp): assert np.all(s[:, 1:] == -50.0) -def test_data_clamp_and_clamp(): +def test_data_clamp_and_clamp(SimpleComp): """Test that data clamp can override (same loc.) and add (another loc.) to clamp.""" - comp = jx.Compartment() + comp = SimpleComp() comp.insert(HH()) comp.record() clamp1 = -50.0 * jnp.ones((1000,)) @@ -208,7 +201,7 @@ def simulate(clamp): s = jitted_simulate(clamp2) assert np.all(s[:, 1:] == -60.0) - comp2 = jx.Compartment() + comp2 = SimpleComp() comp2.insert(HH()) branch1 = jx.Branch(comp, 4) branch2 = jx.Branch(comp2, 4) diff --git a/tests/test_connection.py b/tests/test_connection.py index 5178d24b..c40f1e22 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -21,12 +21,11 @@ from jaxley.synapses import IonotropicSynapse, TestSynapse -def test_connect(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(8)]) - cell = jx.Cell([branch for _ in range(4)], parents=np.array([-1, 0, 0, 0])) - net1 = jx.Network([cell for _ in range(4)]) - net2 = jx.Network([cell for _ in range(4)]) +def test_connect(SimpleBranch, SimpleCell, SimpleNet): + branch = SimpleBranch(4) + cell = SimpleCell(3, 4) + net1 = SimpleNet(4, 3, 8) + net2 = SimpleNet(4, 3, 8) cell1_net1 = net1[0, 0, 0] cell2_net1 = net1[1, 0, 0] @@ -62,7 +61,7 @@ def test_connect(): comp_inds = nodes.loc[first_set_edges[cols].to_numpy().flatten()] branch_inds = comp_inds["global_branch_index"].to_numpy().reshape(-1, 2) cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) - assert np.all(branch_inds == (4, 8)) + assert np.all(branch_inds == (3, 6)) assert (cell_inds == (1, 2)).all() assert ( get_comps(first_set_edges["pre_locs"]) @@ -123,11 +122,8 @@ def test_fully_connect(): ) -def test_sparse_connect(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(4)]) - cell = jx.Cell([branch for _ in range(3)], parents=np.array([-1, 0, 0])) - net = jx.Network([cell for _ in range(4 * 4)]) +def test_sparse_connect(SimpleNet): + net = SimpleNet(4 * 4, 4, 4) _ = np.random.seed(0) for i in range(4): @@ -160,10 +156,8 @@ def test_sparse_connect(): ) -def test_connectivity_matrix_connect(): - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(8)]) - cell = jx.Cell([branch for _ in range(3)], parents=np.array([-1, 0, 0])) +def test_connectivity_matrix_connect(SimpleNet): + net = SimpleNet(4 * 4, 3, 8) _ = np.random.seed(0) n_by_n_adjacency_matrix = np.array( @@ -172,7 +166,6 @@ def test_connectivity_matrix_connect(): incides_of_connected_cells = np.stack(np.where(n_by_n_adjacency_matrix)).T incides_of_connected_cells[:, 1] += 4 - net = jx.Network([cell for _ in range(4 * 4)]) connectivity_matrix_connect( net[:4], net[4:8], TestSynapse(), n_by_n_adjacency_matrix ) @@ -188,7 +181,7 @@ def test_connectivity_matrix_connect(): ) incides_of_connected_cells = np.stack(np.where(m_by_n_adjacency_matrix)).T - net = jx.Network([cell for _ in range(4 * 4)]) + net = SimpleNet(4 * 4, 3, 8) with pytest.raises(AssertionError): connectivity_matrix_connect( net[:4], net[:4], TestSynapse(), m_by_n_adjacency_matrix diff --git a/tests/test_data_feeding.py b/tests/test_data_feeding.py index a5d86565..744c5720 100644 --- a/tests/test_data_feeding.py +++ b/tests/test_data_feeding.py @@ -14,10 +14,8 @@ from jaxley.channels import HH -def test_constant_and_data_stimulus(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=2) - cell = jx.Cell(branch, parents=[-1, 0, 0]) +def test_constant_and_data_stimulus(SimpleCell): + cell = SimpleCell(3, 2) cell.branch(0).loc(0.0).record("v") # test data_stimulate and jit works with trainable parameters see #467 @@ -54,10 +52,8 @@ def simulate(i_amps): assert np.max(diff) < 1e-8 -def test_data_vs_constant_stimulus(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=2) - cell = jx.Cell(branch, parents=[-1, 0, 0]) +def test_data_vs_constant_stimulus(SimpleCell): + cell = SimpleCell(3, 2) cell.branch(0).loc(0.0).record("v") i_amps_data = jnp.asarray([0.01, 0.005]) diff --git a/tests/test_distance.py b/tests/test_distance.py index 29f9ee37..03abdb01 100644 --- a/tests/test_distance.py +++ b/tests/test_distance.py @@ -6,20 +6,14 @@ jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") -import jax.numpy as jnp -import numpy as np -from jax import jit - import jaxley as jx -def test_direct_distance(): +def test_direct_distance(SimpleCell): nseg = 4 length = 15.0 - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=nseg) - cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) + cell = SimpleCell(5, nseg) cell.branch("all").loc("all").set("length", length) cell.compute_xyz() dist = cell.branch(0).loc(0.0).distance(cell.branch(0).loc(1.0)) diff --git a/tests/test_grad.py b/tests/test_grad.py index cf1c7141..f719d570 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -15,14 +15,14 @@ @pytest.mark.parametrize("key", ["HH_m", "v"]) -def test_grad_against_finite_diff_initial_state(key): +def test_grad_against_finite_diff_initial_state(key, SimpleComp): def simulate(): return jnp.sum(jx.integrate(comp)) def simulate_with_params(params): return jnp.sum(jx.integrate(comp, params=params)) - comp = jx.Compartment() + comp = SimpleComp(copy=True) comp.insert(HH()) comp.record() comp.stimulate(jx.step_current(0.1, 0.2, 0.1, 0.025, 5.0)) @@ -50,15 +50,14 @@ def simulate_with_params(params): @pytest.mark.parametrize("key", ["HH_m", "v"]) -def test_branch_grad_against_finite_diff_initial_state(key): +def test_branch_grad_against_finite_diff_initial_state(key, SimpleBranch): def simulate(): return jnp.sum(jx.integrate(branch)) def simulate_with_params(params): return jnp.sum(jx.integrate(branch, params=params)) - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) + branch = SimpleBranch(4) branch.loc(0.0).record() branch.loc(0.0).stimulate(jx.step_current(0.1, 0.2, 0.1, 0.025, 5.0)) branch.loc(0.0).insert(HH()) diff --git a/tests/test_groups.py b/tests/test_groups.py index 00e22ee5..8fd3cfee 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -18,10 +18,8 @@ from jaxley.synapses import IonotropicSynapse -def test_subclassing_groups_cell_api(): - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0, 0, 1, 1]) +def test_subclassing_groups_cell_api(SimpleCell): + cell = SimpleCell(5, 4) cell.branch([0, 3, 4]).add_to_group("subtree") @@ -30,11 +28,8 @@ def test_subclassing_groups_cell_api(): cell.subtree.branch(0).comp("all").make_trainable("length") -def test_subclassing_groups_net_api(): - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1]) - net = jx.Network([cell for _ in range(10)]) +def test_subclassing_groups_net_api(SimpleNet): + net = SimpleNet(10, 2, 4) net.cell([0, 3, 5]).add_to_group("excitatory") @@ -43,13 +38,10 @@ def test_subclassing_groups_net_api(): net.excitatory.cell(0).branch("all").make_trainable("length") -def test_subclassing_groups_net_set_equivalence(): +def test_subclassing_groups_net_set_equivalence(SimpleNet): """Test whether calling `.set` on subclasses group is same as on view.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net1 = jx.Network([cell for _ in range(10)]) - net2 = jx.Network([cell for _ in range(10)]) + net1 = SimpleNet(10, 2, 4) + net2 = SimpleNet(10, 2, 4) net1.cell([0, 3, 5]).add_to_group("excitatory") @@ -65,13 +57,10 @@ def test_subclassing_groups_net_set_equivalence(): assert all(net1.nodes == net2.nodes) -def test_subclassing_groups_net_make_trainable_equivalence(): +def test_subclassing_groups_net_make_trainable_equivalence(SimpleNet): """Test whether calling `.maek_trainable` on subclasses group is same as on view.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net1 = jx.Network([cell for _ in range(10)]) - net2 = jx.Network([cell for _ in range(10)]) + net1 = SimpleNet(10, 2, 4) + net2 = SimpleNet(10, 2, 4) net1.cell([0, 3, 5]).add_to_group("excitatory") @@ -101,13 +90,10 @@ def test_subclassing_groups_net_make_trainable_equivalence(): assert jnp.array_equal(inds1, inds2) -def test_fully_connect_groups_equivalence(): +def test_fully_connect_groups_equivalence(SimpleNet): """Test whether groups can be used with `fully_connect`.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net1 = jx.Network([cell for _ in range(10)]) - net2 = jx.Network([cell for _ in range(10)]) + net1 = SimpleNet(10, 2, 4) + net2 = SimpleNet(10, 2, 4) net1.cell([0, 3, 5]).add_to_group("layer1") net1.cell([6, 8]).add_to_group("layer2") diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index b0fa508f..26fc6021 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -18,17 +18,9 @@ from jaxley.utils.cell_utils import params_to_pstate -def test_make_trainable(): +def test_make_trainable(SimpleCell): """Test make_trainable.""" - nseg_per_branch = 8 - - depth = 5 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - parents = jnp.asarray(parents) - - comp = jx.Compartment() - branch = jx.Branch(comp, nseg_per_branch) - cell = jx.Cell(branch, parents=parents) + cell = SimpleCell(4, 4) cell.insert(HH()) cell.branch(0).loc(0.0).set("length", 12.0) @@ -44,17 +36,9 @@ def test_make_trainable(): cell.get_parameters() -def test_delete_trainables(): +def test_delete_trainables(SimpleCell): """Test make_trainable.""" - nseg_per_branch = 8 - - depth = 5 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - parents = jnp.asarray(parents) - - comp = jx.Compartment() - branch = jx.Branch(comp, nseg_per_branch) - cell = jx.Cell(branch, parents=parents) + cell = SimpleCell(4, 4) cell.branch(0).loc(0.0).make_trainable("length", 12.0) assert cell.num_trainable_params == 1 @@ -66,17 +50,9 @@ def test_delete_trainables(): cell.get_parameters() -def test_make_trainable_network(): +def test_make_trainable_network(SimpleCell): """Test make_trainable.""" - nseg_per_branch = 8 - - depth = 5 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - parents = jnp.asarray(parents) - - comp = jx.Compartment() - branch = jx.Branch(comp, nseg_per_branch) - cell = jx.Cell(branch, parents=parents) + cell = SimpleCell(4, 4) cell.insert(HH()) net = jx.Network([cell, cell]) @@ -99,13 +75,9 @@ def test_make_trainable_network(): assert cell.num_trainable_params == 8 # `set()` is ignored. -def test_diverse_synapse_types(): +def test_diverse_synapse_types(SimpleNet): """Runs `.get_all_parameters()` and checks if the output is as expected.""" - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=1) - cell = jx.Cell(branch, parents=[-1]) - - net = jx.Network([cell for _ in range(4)]) + net = SimpleNet(4, 1, 1) for pre_ind in [0, 1]: for post_ind, syn in zip([2, 3], [IonotropicSynapse(), TestSynapse()]): pre = net.cell(pre_ind).branch(0).loc(0.0) @@ -149,9 +121,10 @@ def test_diverse_synapse_types(): assert np.all(all_parameters["IonotropicSynapse_gS"][1] == 5.5) -def test_make_all_trainable_corresponds_to_set(): +def test_make_all_trainable_corresponds_to_set(SimpleNet): # Scenario 1. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.insert(HH()) params1 = get_params_all_trainable(net1) net2.insert(HH()) @@ -159,7 +132,8 @@ def test_make_all_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) # Scenario 2. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(1).insert(HH()) params1 = get_params_all_trainable(net1) net2.cell(1).insert(HH()) @@ -167,7 +141,8 @@ def test_make_all_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) # Scenario 3. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(1).branch(0).insert(HH()) params1 = get_params_all_trainable(net1) net2.cell(1).branch(0).insert(HH()) @@ -175,7 +150,8 @@ def test_make_all_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) # Scenario 4. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(1).branch(0).loc(0.4).insert(HH()) params1 = get_params_all_trainable(net1) net2.cell(1).branch(0).loc(0.4).insert(HH()) @@ -183,9 +159,10 @@ def test_make_all_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) -def test_make_subset_trainable_corresponds_to_set(): +def test_make_subset_trainable_corresponds_to_set(SimpleNet): # Scenario 1. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.insert(HH()) params1 = get_params_subset_trainable(net1) net2.insert(HH()) @@ -193,7 +170,8 @@ def test_make_subset_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) # Scenario 2. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(0).insert(HH()) params1 = get_params_subset_trainable(net1) net2.cell(0).insert(HH()) @@ -201,7 +179,8 @@ def test_make_subset_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) # Scenario 3. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(0).branch(1).insert(HH()) params1 = get_params_subset_trainable(net1) net2.cell(0).branch(1).insert(HH()) @@ -209,7 +188,8 @@ def test_make_subset_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) # Scenario 4. - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(0).branch(1).loc(0.4).insert(HH()) params1 = get_params_subset_trainable(net1) net2.cell(0).branch(1).loc(0.4).insert(HH()) @@ -217,16 +197,13 @@ def test_make_subset_trainable_corresponds_to_set(): assert np.array_equal(params1["HH_gNa"], params2["HH_gNa"], equal_nan=True) -def test_copy_node_property_to_edges(): +def test_copy_node_property_to_edges(SimpleNet): """Test synaptic parameter sharing via `.copy_node_property_to_edges()`. This test does not explicitly use `make_trainable`, but `copy_node_property_to_edges` is an important ingredient to parameter sharing. """ - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=2) - cell = jx.Cell(branch, parents=[-1, 0]) - net = jx.Network([cell for _ in range(6)]) + net = SimpleNet(6,2,2) net.insert(HH()) net.cell(1).set("HH_gNa", 1.0) net.cell(0).set("radius", 0.2) @@ -281,15 +258,6 @@ def test_copy_node_property_to_edges(): assert np.all(edges_gna_values["post_capacitance"] == 1.0) -def build_two_networks(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1, 0]) - net1 = jx.Network([cell, cell]) - net2 = jx.Network([cell, cell]) - return net1, net2 - - def get_params_subset_trainable(net): net.cell(0).branch(1).make_trainable("HH_gNa") params = net.get_parameters() @@ -324,9 +292,10 @@ def get_params_set(net): return net.get_all_parameters(pstate, voltage_solver="jaxley.thomas") -def test_make_trainable_corresponds_to_set_pospischil(): +def test_make_trainable_corresponds_to_set_pospischil(SimpleNet): """Test whether shared parameters are also set correctly.""" - net1, net2 = build_two_networks() + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(0).insert(Na()) net1.insert(K()) net1.cell("all").branch("all").loc("all").make_trainable("vt") @@ -391,8 +360,9 @@ def build_net(): assert np.allclose(all_parameters1["radius"], all_parameters2["radius"]) -def test_data_set_vs_make_trainable_pospischil(): - net1, net2 = build_two_networks() +def test_data_set_vs_make_trainable_pospischil(SimpleNet): + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) net1.cell(0).insert(Na()) net1.insert(K()) net1.make_trainable("vt") @@ -424,8 +394,9 @@ def test_data_set_vs_make_trainable_pospischil(): assert np.max(np.abs(voltages1 - voltages2)) < 1e-8 -def test_data_set_vs_make_trainable_network(): - net1, net2 = build_two_networks() +def test_data_set_vs_make_trainable_network(SimpleNet): + net1 = SimpleNet(2, 4, 1) + net2 = SimpleNet(2, 4, 1) current = jx.step_current(0.1, 4.0, 0.1, 0.025, 5.0) for net in [net1, net2]: net.insert(HH()) @@ -468,11 +439,8 @@ def test_data_set_vs_make_trainable_network(): assert np.max(np.abs(voltages1 - voltages2)) < 1e-8 -def test_make_states_trainable_api(): - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net = jx.Network([cell for _ in range(2)]) +def test_make_states_trainable_api(SimpleNet): + net = SimpleNet(2, 2, 4) net.insert(HH()) net.cell(0).branch(0).comp(0).record() @@ -489,12 +457,9 @@ def simulate(params): assert np.invert(np.any(np.isnan(v))), "Found NaN in voltage." -def test_write_trainables(): +def test_write_trainables(SimpleNet): """Test whether `write_trainables()` gives the same result as using the trainables.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net = jx.Network([cell for _ in range(2)]) + net = SimpleNet(2, 2, 4) connect( net.cell(0).branch(0).loc(0.9), net.cell(1).branch(1).loc(0.1), diff --git a/tests/test_moving.py b/tests/test_moving.py index d8bb3d9d..6ca6c159 100644 --- a/tests/test_moving.py +++ b/tests/test_moving.py @@ -17,7 +17,7 @@ def test_move_cell(SimpleBranch, SimpleCell): # Test move on a cell with compute_xyz() - cell = SimpleCell(5, nseg=4, copy=True) + cell = SimpleCell(5, nseg=4) cell.compute_xyz() cell.move(20.0, 30.0, 5.0) assert cell.xyzr[0][0, 0] == 20.0 @@ -25,7 +25,7 @@ def test_move_cell(SimpleBranch, SimpleCell): assert cell.xyzr[0][0, 2] == 5.0 # Test move_to on a cell that starts with a specified xyzr - branch = SimpleBranch(nseg=4, copy=True) + branch = SimpleBranch(nseg=4) cell = jx.Cell( branch, parents=[-1], @@ -46,7 +46,7 @@ def test_move_cell(SimpleBranch, SimpleCell): def test_move_network(SimpleCell): - cell = SimpleCell(3, 3, copy=True) + cell = SimpleCell(3, 3) cell.compute_xyz() net = jx.Network([cell, cell, cell]) net.move(20.0, 30.0, 5.0) @@ -57,7 +57,7 @@ def test_move_network(SimpleCell): def test_move_to_cell(SimpleBranch, SimpleCell): - cell = SimpleCell(5, 4, copy=True) + cell = SimpleCell(5, 4) cell.compute_xyz() cell.move_to(20.0, 30.0, 5.0) assert cell.xyzr[0][0, 0] == 20.0 @@ -85,7 +85,7 @@ def test_move_to_cell(SimpleBranch, SimpleCell): def test_move_to_network(SimpleNet): - net = SimpleNet(3, 3, 4, copy=True) + net = SimpleNet(3, 3, 4) net.compute_xyz() net.move_to(10.0, 20.0, 30.0) # Branch 0 of cell 0 @@ -101,7 +101,7 @@ def test_move_to_network(SimpleNet): def test_move_to_arrays(SimpleNet): """Test with network""" nseg = 4 - net = SimpleNet(3, 3, nseg, copy=True) + net = SimpleNet(3, 3, nseg) net.compute_xyz() x_coords = np.array([10.0, 20.0, 30.0]) y_coords = np.array([5.0, 15.0, 25.0]) @@ -117,7 +117,7 @@ def test_move_to_arrays(SimpleNet): def test_move_to_cellview(net): - net = net(3, 3, 2, copy=True) + net = net(3, 3, 2) net.compute_xyz() # Test with float input @@ -127,7 +127,7 @@ def test_move_to_cellview(net): assert net.xyzr[0][0, 2] == 40.0 # Test with array input - net = net(4, 3, 2, copy=True) + net = net(4, 3, 2) net.compute_xyz() testx = np.array([1.0, 2.0, 3.0]) testy = np.array([4.0, 5.0, 6.0]) @@ -142,9 +142,9 @@ def test_move_to_cellview(net): def test_move_to_swc_cell(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") - cell1 = SimpleMorphCell(fname, nseg=4, copy=True) - cell2 = SimpleMorphCell(fname, nseg=4, copy=True) - cell3 = SimpleMorphCell(fname, nseg=4, copy=True) + cell1 = SimpleMorphCell(fname, nseg=4) + cell2 = SimpleMorphCell(fname, nseg=4) + cell3 = SimpleMorphCell(fname, nseg=4) # Try move_to on a cell cell1.move_to(10.0, 20.0, 30.0) diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py index 5d476add..25186a13 100644 --- a/tests/test_record_and_stimulate.py +++ b/tests/test_record_and_stimulate.py @@ -18,7 +18,7 @@ def test_record_and_stimulate_api(SimpleCell): """Test the API for recording and stimulating.""" - cell = SimpleCell(3, 2, copy=True) + cell = SimpleCell(3, 2) cell.branch(0).loc(0.0).record() cell.branch(1).loc(1.0).record() @@ -32,7 +32,7 @@ def test_record_and_stimulate_api(SimpleCell): def test_record_shape(SimpleCell): """Test the API for recording and stimulating.""" - cell = SimpleCell(3, 2, copy=True) + cell = SimpleCell(3, 2) current = jx.step_current(0.0, 1.0, 1.0, 0.025, 3.0) cell.branch(1).loc(1.0).stimulate(current) @@ -59,7 +59,7 @@ def test_record_synaptic_and_membrane_states(SimpleNet): _ = np.random.seed(0) # Seed because connectivity is at random postsyn locs. - net = SimpleNet(3, 1, 4, copy=True) + net = SimpleNet(3, 1, 4) net.insert(HH()) fully_connect(net.cell([0]), net.cell([1]), IonotropicSynapse()) @@ -107,9 +107,9 @@ def test_record_synaptic_and_membrane_states(SimpleNet): assert np.all(np.abs(maxima_3 - maxima_1 - offset_mem)) < 5.0 -def test_empty_recordings(): +def test_empty_recordings(SimpleComp): # Create an empty compartment - comp = jx.Compartment() + comp = SimpleComp() # Check if a ValueError is raised when integrating an empty compartment with pytest.raises(ValueError): diff --git a/tests/test_set_ncomp.py b/tests/test_set_ncomp.py index f154ef62..eeed09f8 100644 --- a/tests/test_set_ncomp.py +++ b/tests/test_set_ncomp.py @@ -20,8 +20,8 @@ "property", ["radius", "capacitance", "length", "axial_resistivity"] ) def test_raise_for_heterogenous_modules(property, SimpleBranch): - branch0 = SimpleBranch(4, copy=True) - branch1 = SimpleBranch(4, copy=True) + branch0 = SimpleBranch(4) + branch1 = SimpleBranch(4) branch1.comp(1).set(property, 1.5) cell = jx.Cell([branch0, branch1], parents=[-1, 0]) with pytest.raises(ValueError): @@ -29,8 +29,8 @@ def test_raise_for_heterogenous_modules(property, SimpleBranch): def test_raise_for_heterogenous_channel_existance(SimpleBranch): - branch0 = SimpleBranch(4, copy=True) - branch1 = SimpleBranch(4, copy=True) + branch0 = SimpleBranch(4) + branch1 = SimpleBranch(4) branch1.comp(2).insert(HH()) cell = jx.Cell([branch0, branch1], parents=[-1, 0]) with pytest.raises(ValueError): @@ -38,8 +38,8 @@ def test_raise_for_heterogenous_channel_existance(SimpleBranch): def test_raise_for_heterogenous_channel_properties(SimpleBranch): - branch0 = SimpleBranch(4, copy=True) - branch1 = SimpleBranch(4, copy=True) + branch0 = SimpleBranch(4) + branch1 = SimpleBranch(4) branch1.insert(HH()) branch1.comp(3).set("HH_gNa", 0.5) cell = jx.Cell([branch0, branch1], parents=[-1, 0]) @@ -62,14 +62,14 @@ def test_raise_for_networks(SimpleCell): def test_raise_for_recording(SimpleCell): - cell = SimpleCell(3, 2, copy=True) + cell = SimpleCell(3, 2) cell.branch(0).comp(0).record() with pytest.raises(AssertionError): cell.branch(1).set_ncomp(2) def test_raise_for_stimulus(SimpleCell): - cell = SimpleCell(3, 2, copy=True) + cell = SimpleCell(3, 2) cell.branch(0).comp(0).stimulate(0.4 * jnp.ones(100)) with pytest.raises(AssertionError): cell.branch(1).set_ncomp(2) @@ -83,11 +83,11 @@ def test_simulation_accuracy_api_equivalence_init_vs_setncomp_branch( This makes one branch, whose `ncomp` is not modified, heterogenous. """ - branch1 = SimpleBranch(new_ncomp, copy=True) + branch1 = SimpleBranch(new_ncomp) # The second branch is originally instantiated to have 4 ncomp, but is later # modified to have `new_ncomp` compartments. - branch2 = SimpleBranch(4, copy=True) + branch2 = SimpleBranch(4) branch2.comp("all").set("length", 10.0) total_branch_len = 4 * 10.0 @@ -112,11 +112,11 @@ def test_simulation_accuracy_api_equivalence_init_vs_setncomp_cell( new_ncomp, SimpleBranch ): """Test whether a module built from scratch matches module built with `set_ncomp()`.""" - branch1 = SimpleBranch(new_ncomp, copy=True) + branch1 = SimpleBranch(new_ncomp) # The second branch is originally instantiated to have 4 ncomp, but is later # modified to have `new_ncomp` compartments. - branch2 = SimpleBranch(4, copy=True) + branch2 = SimpleBranch(4) branch2.comp("all").set("length", 10.0) total_branch_len = 4 * 10.0 @@ -146,8 +146,8 @@ def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file, SimpleMorphCe dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) - cell1 = SimpleMorphCell(fname, nseg=new_ncomp, copy=True) - cell2 = SimpleMorphCell(fname, nseg=4, copy=True) + cell1 = SimpleMorphCell(fname, nseg=new_ncomp) + cell2 = SimpleMorphCell(fname, nseg=4) for b in range(cell2.total_nbranches): cell2.branch(b).set_ncomp(new_ncomp) @@ -167,8 +167,8 @@ def test_simulation_accuracy_swc_init_vs_set_ncomp(new_ncomp, file, SimpleMorphC dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) - cell1 = SimpleMorphCell(fname, nseg=new_ncomp, copy=True) - cell2 = SimpleMorphCell(fname, nseg=4, copy=True) + cell1 = SimpleMorphCell(fname, nseg=new_ncomp) + cell2 = SimpleMorphCell(fname, nseg=4) for b in range(cell2.total_nbranches): cell2.branch(b).set_ncomp(new_ncomp) diff --git a/tests/test_solver.py b/tests/test_solver.py index be42f8d5..88a86c59 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -28,7 +28,7 @@ def test_fwd_euler_and_crank_nicolson(SimpleNet): Tests whether forward Euler and Crank-Nicolson are sufficiently close to implicit Euler.""" - net = SimpleNet(2, 1, 4, connect=True, copy=True) + net = SimpleNet(2, 1, 4, connect=True) current = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0) net.cell(0).branch(0).comp(0).stimulate(current) diff --git a/tests/test_swc.py b/tests/test_swc.py index 52757b99..53393aa8 100644 --- a/tests/test_swc.py +++ b/tests/test_swc.py @@ -23,11 +23,11 @@ # Test is failing for "morph.swc". This is because NEURON and Jaxley handle interrupted # soma differently, see issue #140. @pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph_minimal.swc"]) -def test_swc_reader_lengths(file): +def test_swc_reader_lengths(file, swc2jaxley): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) - _, pathlengths, _, _, _ = jx.utils.swc.swc_to_jaxley(fname, max_branch_len=2000.0) + _, pathlengths, _, _, _ = swc2jaxley(fname, max_branch_len=2000.0) if pathlengths[0] == 0.1: pathlengths = pathlengths[1:] @@ -53,19 +53,17 @@ def test_swc_reader_lengths(file): ), "Number of branches does not match." -def test_dummy_compartment_length(): +def test_dummy_compartment_length(swc2jaxley): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph_soma_both_ends.swc") - parents, pathlengths, _, _, _ = jx.utils.swc.swc_to_jaxley( - fname, max_branch_len=2000.0 - ) + parents, pathlengths, _, _, _ = swc2jaxley(fname, max_branch_len=2000.0) assert parents == [-1, 0, 0, 1] assert pathlengths == [0.1, 1.0, 2.6, 2.2] @pytest.mark.parametrize("file", ["morph_250_single_point_soma.swc", "morph_250.swc"]) -def test_swc_radius(file): +def test_swc_radius(file, swc2jaxley): """We expect them to match for sufficiently large nseg. See #140.""" nseg = 64 non_split = 1 / nseg @@ -75,9 +73,7 @@ def test_swc_radius(file): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) - _, pathlen, radius_fns, _, _ = jx.utils.swc.swc_to_jaxley( - fname, max_branch_len=2000.0, sort=False - ) + _, pathlen, radius_fns, _, _ = swc2jaxley(fname, max_branch_len=2000.0, sort=False) jaxley_diams = [] for r in radius_fns: jaxley_diams.append(r(range_16) * 2) @@ -105,7 +101,7 @@ def test_swc_radius(file): @pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"]) -def test_swc_voltages(file, SimpleMorphCell): +def test_swc_voltages(file, SimpleMorphCell, swc2jaxley): """Check if voltages of SWC recording match. To match the branch indices between NEURON and jaxley, we rely on comparing the @@ -142,8 +138,8 @@ def test_swc_voltages(file, SimpleMorphCell): pathlengths_neuron = np.asarray([sec.L for sec in h.allsec()]) ####################### jaxley ################## - _, pathlengths, _, _, _ = jx.utils.swc.swc_to_jaxley(fname, max_branch_len=2_000) - cell = SimpleMorphCell(fname, nseg_per_branch, max_branch_len=2_000.0, copy=True) + _, pathlengths, _, _, _ = swc2jaxley(fname, max_branch_len=2_000) + cell = SimpleMorphCell(fname, nseg_per_branch, max_branch_len=2_000.0) cell.insert(HH()) trunk_inds = [1, 4, 5, 13, 15, 21, 23, 24, 29, 33] diff --git a/tests/test_syn.py b/tests/test_syn.py index 656bfa42..3159e036 100644 --- a/tests/test_syn.py +++ b/tests/test_syn.py @@ -18,7 +18,7 @@ def test_set_and_querying_params_one_type(SimpleNet): """Test if the correct parameters are set if one type of synapses is inserted.""" - net = SimpleNet(4, 1, 4, copy=True) + net = SimpleNet(4, 1, 4) for pre_ind in [0, 1]: for post_ind in [2, 3]: diff --git a/tests/test_synapse_indexing.py b/tests/test_synapse_indexing.py index 69b4ece8..150a5d83 100644 --- a/tests/test_synapse_indexing.py +++ b/tests/test_synapse_indexing.py @@ -23,7 +23,7 @@ def test_multiparameter_setting(SimpleNet): Tests global index dropping: d4daaf019596589b9430219a15f1dda0b1c34d85 """ - net = SimpleNet(2, 1, 4, copy=True) + net = SimpleNet(2, 1, 4) pre = net.cell(0).branch(0).loc(0.0) post = net.cell(1).branch(0).loc(0.0) @@ -59,7 +59,7 @@ def _get_synapse_view(net, synapse_name, single_idx=1, double_idxs=[2, 3]): def test_set_and_querying_params_one_type(synapse_type, SimpleNet): """Test if the correct parameters are set if one type of synapses is inserted.""" synapse_type = synapse_type() - net = SimpleNet(4, 1, 4, copy=True) + net = SimpleNet(4, 1, 4) for pre_ind in [0, 1]: for post_ind in [2, 3]: @@ -97,7 +97,7 @@ def test_set_and_querying_params_one_type(synapse_type, SimpleNet): def test_set_and_querying_params_two_types(synapse_type, SimpleNet): """Test whether the correct parameters are set.""" synapse_type = synapse_type() - net = SimpleNet(4, 1, 4, copy=True) + net = SimpleNet(4, 1, 4) for pre_ind in [0, 1]: for post_ind, synapse in zip([2, 3], [IonotropicSynapse(), synapse_type]): @@ -154,8 +154,8 @@ def test_shuffling_order_of_set(synapse_type, SimpleNet): """Test whether the result is the same if the order of synapses is changed.""" synapse_type = synapse_type() - net1 = SimpleNet(4, 1, 4, copy=True) - net2 = SimpleNet(4, 1, 4, copy=True) + net1 = SimpleNet(4, 1, 4) + net2 = SimpleNet(4, 1, 4) connect( net1.cell(0).branch(0).loc(1.0), diff --git a/tests/test_transforms.py b/tests/test_transforms.py index a5672a68..af227542 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -152,7 +152,7 @@ def test_correct(transform): [jt.SigmoidTransform(-2, 2), jt.SoftplusTransform(2), jt.NegSoftplusTransform(2)], ) def test_user_api(transform, SimpleCell): - cell = SimpleCell(3, 2, copy=True) + cell = SimpleCell(3, 2) cell.branch("all").make_trainable("radius") cell.branch(2).make_trainable("radius") diff --git a/tests/test_viewing.py b/tests/test_viewing.py index cb75130b..ce7a9440 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -341,7 +341,7 @@ def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, Simple def test_select(SimpleNet): """Ensure `select` works correctly and returns expected View of Modules.""" - net = SimpleNet(3, 3, 2, connect=False, copy=True) + net = SimpleNet(3, 3, 2, connect=False) connect(net[0, 0, :], net[1, 0, :], TestSynapse()) np.random.seed(0) @@ -546,7 +546,7 @@ def test_iter(SimpleBranch): def test_synapse_and_channel_filtering(SimpleNet): """Test that synapses and channels are filtered correctly by View.""" - net = SimpleNet(3, 3, 3, connect=False, copy=True) + net = SimpleNet(3, 3, 3, connect=False) net.insert(HH()) connect(net[0, 0, :], net[1, 0, :], TestSynapse()) @@ -573,7 +573,7 @@ def test_synapse_and_channel_filtering(SimpleNet): def test_view_equals_module(SimpleComp, SimpleBranch): """Test that View behaves the same as Module for important attrs and methods.""" comp = SimpleComp(copy=True) - branch = SimpleBranch(3, copy=True) + branch = SimpleBranch(3) comp.insert(HH()) branch.comp([0, 1]).insert(HH()) From 5de51f0b84916b7bdf87e914661e664f7437855e Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 14 Nov 2024 20:12:19 +0100 Subject: [PATCH 03/15] fix: ran black after rebase --- tests/test_make_trainable.py | 2 +- tests/test_plotting_api.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 26fc6021..1a84bc11 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -203,7 +203,7 @@ def test_copy_node_property_to_edges(SimpleNet): This test does not explicitly use `make_trainable`, but `copy_node_property_to_edges` is an important ingredient to parameter sharing. """ - net = SimpleNet(6,2,2) + net = SimpleNet(6, 2, 2) net.insert(HH()) net.cell(1).set("HH_gNa", 1.0) net.cell(0).set("radius", 0.2) diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index 57e7a121..7600496e 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -35,6 +35,7 @@ def test_cell(SimpleMorphCell): cell.branch(1).add_to_group("soma") ax = cell.soma.vis() + def test_network(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") @@ -159,7 +160,7 @@ def test_volume_plotting(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): comp = SimpleComp() branch = SimpleBranch(4) cell = SimpleCell(3, 4) - net = SimpleNet(2,3,4) + net = SimpleNet(2, 3, 4) for module in [comp, branch, cell, net]: module.compute_xyz() @@ -183,4 +184,4 @@ def test_volume_plotting(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): morph_cell.branch(1).vis( type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6} ) # plotting whole thing takes too long - plt.close() \ No newline at end of file + plt.close() From b8aa8270172557f5023e6b1f333066885ba01cba Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 14 Nov 2024 20:30:35 +0100 Subject: [PATCH 04/15] fix: make tests pass --- tests/test_api_equivalence.py | 4 ++-- tests/test_cell_matches_branch.py | 12 +++++------- tests/test_moving.py | 6 +++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py index fa11f839..9751a331 100644 --- a/tests/test_api_equivalence.py +++ b/tests/test_api_equivalence.py @@ -103,12 +103,12 @@ def test_solver_backends_net(SimpleNet): connect( net.cell(0).branch(0).loc(1.0), - net.cell(1).branch(4).loc(1.0), + net.cell(1).branch(3).loc(1.0), IonotropicSynapse(), ) connect( net.cell(1).branch(1).loc(0.8), - net.cell(0).branch(4).loc(0.1), + net.cell(0).branch(3).loc(0.1), IonotropicSynapse(), ) diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index a1a3d274..c5eeae02 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -15,8 +15,7 @@ from jaxley.channels import HH -def _run_long_branch(dt, t_max, SimpleBranch): - branch = SimpleBranch(8) +def _run_long_branch(dt, t_max, branch): branch.insert(HH()) branch.loc("all").make_trainable("radius", 1.0) @@ -35,8 +34,7 @@ def loss(params): return l, g -def _run_short_branches(dt, t_max, SimpleCell): - cell = SimpleCell(2, 4) +def _run_short_branches(dt, t_max, cell): cell.insert(HH()) cell.branch("all").loc("all").make_trainable("radius", 1.0) @@ -55,12 +53,12 @@ def loss(params): return l, g -def test_equivalence(): +def test_equivalence(SimpleBranch, SimpleCell): """Test whether a single long branch matches a cell of two shorter branches.""" dt = 0.025 t_max = 5.0 # ms - l1, g1 = _run_long_branch(dt, t_max) - l2, g2 = _run_short_branches(dt, t_max) + l1, g1 = _run_long_branch(dt, t_max, SimpleBranch(8)) + l2, g2 = _run_short_branches(dt, t_max, SimpleCell(2, 4)) assert np.allclose(l1, l2), "Losses do not match." diff --git a/tests/test_moving.py b/tests/test_moving.py index 6ca6c159..27bb8dc4 100644 --- a/tests/test_moving.py +++ b/tests/test_moving.py @@ -116,8 +116,8 @@ def test_move_to_arrays(SimpleNet): assert net.xyzr[6][0, 1] == 25.0 -def test_move_to_cellview(net): - net = net(3, 3, 2) +def test_move_to_cellview(SimpleNet): + net = SimpleNet(3, 3, 2) net.compute_xyz() # Test with float input @@ -127,7 +127,7 @@ def test_move_to_cellview(net): assert net.xyzr[0][0, 2] == 40.0 # Test with array input - net = net(4, 3, 2) + net = SimpleNet(4, 3, 2) net.compute_xyz() testx = np.array([1.0, 2.0, 3.0]) testy = np.array([4.0, 5.0, 6.0]) From 60cfe760374235e1af827d356c3e4c1a781f531d Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 15 Nov 2024 13:05:01 +0100 Subject: [PATCH 05/15] enh: add slow markers and add remaining fixtures --- pyproject.toml | 5 ++ tests/conftest.py | 2 +- tests/jaxley_identical/test_basic_modules.py | 77 +++++++------------ tests/jaxley_identical/test_grad.py | 10 +-- .../test_radius_and_length.py | 62 ++++++--------- tests/jaxley_identical/test_swc.py | 12 +-- tests/test_api_equivalence.py | 2 + tests/test_cell_matches_branch.py | 8 +- tests/test_channels.py | 1 + tests/test_moving.py | 6 +- tests/test_plotting_api.py | 28 +++---- tests/test_set_ncomp.py | 4 +- 12 files changed, 94 insertions(+), 123 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e382acbb..d37e2bda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,11 @@ dev = [ "jupyter", ] +[tool.pytest.ini_options] +markers = [ + "slow: marks tests as slow (T > 10s)", +] + [tool.isort] profile = "black" diff --git a/tests/conftest.py b/tests/conftest.py index 8cf46ac5..55f9b8b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,7 +79,7 @@ def SimpleMorphCell(): cells = {} - def cell_w_params(fname=None, nseg=2, max_branch_len=2_000.0, copy=True): + def cell_w_params(fname=None, nseg=1, max_branch_len=2_000.0, copy=True): fname = default_fname if fname is None else fname if key := (fname, nseg, max_branch_len) not in cells: cells[key] = jx.read_swc(fname, nseg, max_branch_len, assign_groups=True) diff --git a/tests/jaxley_identical/test_basic_modules.py b/tests/jaxley_identical/test_basic_modules.py index 8faba4b3..577878c2 100644 --- a/tests/jaxley_identical/test_basic_modules.py +++ b/tests/jaxley_identical/test_basic_modules.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_compartment(voltage_solver: str): +def test_compartment(voltage_solver, SimpleComp, SimpleBranch, SimpleCell, SimpleNet): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) @@ -48,7 +48,7 @@ def test_compartment(voltage_solver: str): ) # Test compartment. - comp = jx.Compartment() + comp = SimpleComp() comp.insert(HH()) comp.record() comp.stimulate(current) @@ -57,7 +57,7 @@ def test_compartment(voltage_solver: str): assert max_error <= tolerance, f"Compartment error is {max_error} > {tolerance}" # Test branch of a single compartment. - branch = jx.Branch() + branch = SimpleBranch(nseg=1) branch.insert(HH()) branch.record() branch.stimulate(current) @@ -66,7 +66,7 @@ def test_compartment(voltage_solver: str): assert max_error <= tolerance, f"Branch error is {max_error} > {tolerance}" # Test cell of a single compartment. - cell = jx.Cell() + cell = SimpleCell(1, 1) cell.insert(HH()) cell.record() cell.stimulate(current) @@ -75,8 +75,7 @@ def test_compartment(voltage_solver: str): assert max_error <= tolerance, f"Cell error is {max_error} > {tolerance}" # Test net of a single compartment. - cell = jx.Cell() - net = jx.Network([cell]) + net = SimpleNet(1, 1, 1) net.insert(HH()) net.record() net.stimulate(current) @@ -86,14 +85,12 @@ def test_compartment(voltage_solver: str): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_branch(voltage_solver: str): - nseg_per_branch = 2 +def test_branch(voltage_solver, SimpleBranch): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) + branch = SimpleBranch(2) branch.insert(HH()) branch.loc(0.0).record() branch.loc(0.0).stimulate(current) @@ -122,13 +119,12 @@ def test_branch(voltage_solver: str): assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" -def test_branch_fwd_euler_uneven_radiuses(): +def test_branch_fwd_euler_uneven_radiuses(SimpleBranch): dt = 0.025 # ms t_max = 10.0 # ms current = jx.step_current(0.5, 1.0, 2.0, dt, t_max) - comp = jx.Compartment() - branch = jx.Branch(comp, 8) + branch = SimpleBranch(8) branch.set("axial_resistivity", 500.0) rands1 = np.linspace(20, 300, 8) @@ -161,18 +157,12 @@ def test_branch_fwd_euler_uneven_radiuses(): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_cell(voltage_solver: str): - nseg_per_branch = 2 +def test_cell(voltage_solver, SimpleCell): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) - depth = 2 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) - cell = jx.Cell([branch for _ in range(len(parents))], parents=parents) + cell = SimpleCell(3, 2) cell.insert(HH()) cell.branch(1).loc(0.0).record() cell.branch(1).loc(0.0).stimulate(current) @@ -201,17 +191,16 @@ def test_cell(voltage_solver: str): assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" -def test_cell_unequal_compartment_number(): +def test_cell_unequal_compartment_number(SimpleBranch): """Tests a cell where every branch has a different number of compartments.""" dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.1, dt, t_max) - comp = jx.Compartment() - branch1 = jx.Branch(comp, nseg=1) - branch2 = jx.Branch(comp, nseg=2) - branch3 = jx.Branch(comp, nseg=3) - branch4 = jx.Branch(comp, nseg=4) + branch1 = SimpleBranch(nseg=1) + branch2 = SimpleBranch(nseg=2) + branch3 = SimpleBranch(nseg=3) + branch4 = SimpleBranch(nseg=4) cell = jx.Cell([branch1, branch2, branch3, branch4], parents=[-1, 0, 0, 1]) cell.set("axial_resistivity", 10_000.0) cell.insert(HH()) @@ -236,40 +225,32 @@ def test_cell_unequal_compartment_number(): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_net(voltage_solver: str): - nseg_per_branch = 2 +def test_net(voltage_solver, SimpleNet): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) - depth = 2 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) - cell1 = jx.Cell([branch for _ in range(len(parents))], parents=parents) - cell2 = jx.Cell([branch for _ in range(len(parents))], parents=parents) + net = SimpleNet(2, 3, 2) - network = jx.Network([cell1, cell2]) connect( - network.cell(0).branch(0).loc(0.0), - network.cell(1).branch(0).loc(0.0), + net.cell(0).branch(0).loc(0.0), + net.cell(1).branch(0).loc(0.0), IonotropicSynapse(), ) - network.insert(HH()) + net.insert(HH()) for cell_ind in range(2): - network.cell(cell_ind).branch(1).loc(0.0).record() + net.cell(cell_ind).branch(1).loc(0.0).record() for stim_ind in range(2): - network.cell(stim_ind).branch(1).loc(0.0).stimulate(current) + net.cell(stim_ind).branch(1).loc(0.0).stimulate(current) area = 2 * pi * 10.0 * 1.0 point_process_to_dist_factor = 100_000.0 / area - network.IonotropicSynapse.set( + net.IonotropicSynapse.set( "IonotropicSynapse_gS", 0.5 / point_process_to_dist_factor ) - voltages = jx.integrate(network, delta_t=dt, voltage_solver=voltage_solver) + voltages = jx.integrate(net, delta_t=dt, voltage_solver=voltage_solver) voltages_300724 = jnp.asarray( [ @@ -308,12 +289,8 @@ def test_net(voltage_solver: str): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_complex_net(voltage_solver: str): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) - - net = jx.Network([cell for _ in range(7)]) +def test_complex_net(voltage_solver, SimpleNet): + net = SimpleNet(7, 5, 4) net.insert(HH()) _ = np.random.seed(0) diff --git a/tests/jaxley_identical/test_grad.py b/tests/jaxley_identical/test_grad.py index 198201bc..b6c98d30 100644 --- a/tests/jaxley_identical/test_grad.py +++ b/tests/jaxley_identical/test_grad.py @@ -14,6 +14,7 @@ import jax.numpy as jnp import numpy as np +import pytest from jax import value_and_grad import jaxley as jx @@ -22,12 +23,9 @@ from jaxley.synapses import IonotropicSynapse, TestSynapse -def test_network_grad(): - comp = jx.Compartment() - branch = jx.Branch(comp, nseg=4) - cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) - - net = jx.Network([cell for _ in range(7)]) +@pytest.mark.slow +def test_network_grad(SimpleNet): + net = SimpleNet(7, 5, 4) net.insert(HH()) _ = np.random.seed(0) diff --git a/tests/jaxley_identical/test_radius_and_length.py b/tests/jaxley_identical/test_radius_and_length.py index c68a3e5f..f0ef3c83 100644 --- a/tests/jaxley_identical/test_radius_and_length.py +++ b/tests/jaxley_identical/test_radius_and_length.py @@ -23,12 +23,12 @@ @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_radius_and_length_compartment(voltage_solver: str): +def test_radius_and_length_compartment(voltage_solver, SimpleComp): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) - comp = jx.Compartment() + comp = SimpleComp() np.random.seed(1) comp.set("length", 5 * np.random.rand(1)) @@ -63,14 +63,12 @@ def test_radius_and_length_compartment(voltage_solver: str): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_radius_and_length_branch(voltage_solver: str): - nseg_per_branch = 2 +def test_radius_and_length_branch(voltage_solver, SimpleBranch): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) + branch = SimpleBranch(nseg=2) np.random.seed(1) branch.set("length", np.flip(5 * np.random.rand(2))) @@ -105,19 +103,13 @@ def test_radius_and_length_branch(voltage_solver: str): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_radius_and_length_cell(voltage_solver: str): - nseg_per_branch = 2 +def test_radius_and_length_cell(voltage_solver, SimpleCell): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) - depth = 2 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - num_branches = len(parents) - - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) - cell = jx.Cell([branch for _ in range(len(parents))], parents=parents) + num_branches = 3 + cell = SimpleCell(num_branches, nseg=2) np.random.seed(1) rands1 = 5 * np.random.rand(2 * num_branches) @@ -155,57 +147,49 @@ def test_radius_and_length_cell(voltage_solver: str): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_radius_and_length_net(voltage_solver: str): - nseg_per_branch = 2 +def test_radius_and_length_net(voltage_solver, SimpleNet): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) - depth = 2 - parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] - num_branches = len(parents) - - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) - cell1 = jx.Cell([branch for _ in range(len(parents))], parents=parents) - cell2 = jx.Cell([branch for _ in range(len(parents))], parents=parents) + num_branches = 3 + net = SimpleNet(2, num_branches, 2) np.random.seed(1) rands1 = 5 * np.random.rand(2 * num_branches) rands2 = np.random.rand(2 * num_branches) for b in range(num_branches): - cell1.branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2])) - cell1.branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2])) + net.cell(0).branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2])) + net.cell(0).branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2])) np.random.seed(2) rands1 = 5 * np.random.rand(2 * num_branches) rands2 = np.random.rand(2 * num_branches) for b in range(num_branches): - cell2.branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2])) - cell2.branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2])) + net.cell(1).branch(b).set("length", np.flip(rands1[2 * b : 2 * b + 2])) + net.cell(1).branch(b).set("radius", np.flip(rands2[2 * b : 2 * b + 2])) - network = jx.Network([cell1, cell2]) connect( - network.cell(0).branch(0).loc(0.0), - network.cell(1).branch(0).loc(0.0), + net.cell(0).branch(0).loc(0.0), + net.cell(1).branch(0).loc(0.0), IonotropicSynapse(), ) - network.insert(HH()) + net.insert(HH()) # first cell, 0-eth branch, 0-st compartment because loc=0.0 - radius_post = network[1, 0, 0].nodes["radius"].item() - lenght_post = network[1, 0, 0].nodes["length"].item() + radius_post = net[1, 0, 0].nodes["radius"].item() + lenght_post = net[1, 0, 0].nodes["length"].item() area = 2 * pi * lenght_post * radius_post point_process_to_dist_factor = 100_000.0 / area - network.set("IonotropicSynapse_gS", 0.5 / point_process_to_dist_factor) + net.set("IonotropicSynapse_gS", 0.5 / point_process_to_dist_factor) for cell_ind in range(2): - network.cell(cell_ind).branch(1).loc(0.0).record() + net.cell(cell_ind).branch(1).loc(0.0).record() for stim_ind in range(2): - network.cell(stim_ind).branch(1).loc(0.0).stimulate(current) + net.cell(stim_ind).branch(1).loc(0.0).stimulate(current) - voltages = jx.integrate(network, delta_t=dt, voltage_solver=voltage_solver) + voltages = jx.integrate(net, delta_t=dt, voltage_solver=voltage_solver) voltages_300724 = jnp.asarray( [ diff --git a/tests/jaxley_identical/test_swc.py b/tests/jaxley_identical/test_swc.py index b2773a3b..eccdfd19 100644 --- a/tests/jaxley_identical/test_swc.py +++ b/tests/jaxley_identical/test_swc.py @@ -21,16 +21,17 @@ from jaxley.synapses import IonotropicSynapse +@pytest.mark.slow @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) @pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"]) -def test_swc_cell(voltage_solver: str, file: str): +def test_swc_cell(voltage_solver: str, file: str, SimpleMorphCell): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.2, dt, t_max) dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "../swc_files", file) - cell = jx.read_swc(fname, nseg=2, max_branch_len=300.0, assign_groups=True) + cell = SimpleMorphCell(fname, nseg=2, max_branch_len=300.0) _ = cell.soma # Only to test whether the `soma` group was created. cell.insert(HH()) cell.branch(1).loc(0.0).record() @@ -81,16 +82,17 @@ def test_swc_cell(voltage_solver: str, file: str): assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" +@pytest.mark.slow @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) -def test_swc_net(voltage_solver: str): +def test_swc_net(voltage_solver: str, SimpleMorphCell): dt = 0.025 # ms t_max = 5.0 # ms current = jx.step_current(0.5, 1.0, 0.2, dt, t_max) dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "../swc_files/morph.swc") - cell1 = jx.read_swc(fname, nseg=2, max_branch_len=300.0) - cell2 = jx.read_swc(fname, nseg=2, max_branch_len=300.0) + cell1 = SimpleMorphCell(fname, nseg=2, max_branch_len=300.0) + cell2 = SimpleMorphCell(fname, nseg=2, max_branch_len=300.0) network = jx.Network([cell1, cell2]) connect( diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py index 9751a331..be0c1938 100644 --- a/tests/test_api_equivalence.py +++ b/tests/test_api_equivalence.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import numpy as np +import pytest import jaxley as jx from jaxley.channels import HH @@ -80,6 +81,7 @@ def test_solver_backends_branch(SimpleBranch): assert max_error < 1e-8, f"{message} thomas/stone. Error={max_error}" +@pytest.mark.slow def test_solver_backends_cell(SimpleCell): """Test whether ways of adding synapses are equivalent.""" cell = SimpleCell(4, 4) diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index c5eeae02..a2375513 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -9,6 +9,7 @@ import jax.numpy as jnp import numpy as np +import pytest from jax import jit, value_and_grad import jaxley as jx @@ -22,7 +23,7 @@ def _run_long_branch(dt, t_max, branch): params = branch.get_parameters() branch.loc(0.0).record() - branch.loc(0.0).stimulate(jx.step_current(0.5, 5.0, 0.1, dt, t_max)) + branch.loc(0.0).stimulate(jx.step_current(0.2, 2.0, 0.1, dt, t_max)) def loss(params): s = jx.integrate(branch, params=params) @@ -41,7 +42,7 @@ def _run_short_branches(dt, t_max, cell): params = cell.get_parameters() cell.branch(0).loc(0.0).record() - cell.branch(0).loc(0.0).stimulate(jx.step_current(0.5, 5.0, 0.1, dt, t_max)) + cell.branch(0).loc(0.0).stimulate(jx.step_current(0.2, 2.0, 0.1, dt, t_max)) def loss(params): s = jx.integrate(cell, params=params) @@ -53,10 +54,11 @@ def loss(params): return l, g +@pytest.mark.slow def test_equivalence(SimpleBranch, SimpleCell): """Test whether a single long branch matches a cell of two shorter branches.""" dt = 0.025 - t_max = 5.0 # ms + t_max = 2.0 # ms l1, g1 = _run_long_branch(dt, t_max, SimpleBranch(8)) l2, g2 = _run_short_branches(dt, t_max, SimpleCell(2, 4)) diff --git a/tests/test_channels.py b/tests/test_channels.py index 5870a10f..e4722334 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -164,6 +164,7 @@ def test_integration_with_renamed_channels(): assert np.invert(np.any(np.isnan(v))) +@pytest.mark.slow def test_init_states(SimpleCell): """Functional test for `init_states()`. diff --git a/tests/test_moving.py b/tests/test_moving.py index 27bb8dc4..e0ef0403 100644 --- a/tests/test_moving.py +++ b/tests/test_moving.py @@ -142,9 +142,9 @@ def test_move_to_cellview(SimpleNet): def test_move_to_swc_cell(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") - cell1 = SimpleMorphCell(fname, nseg=4) - cell2 = SimpleMorphCell(fname, nseg=4) - cell3 = SimpleMorphCell(fname, nseg=4) + cell1 = SimpleMorphCell(fname, nseg=1) + cell2 = SimpleMorphCell(fname, nseg=1) + cell3 = SimpleMorphCell(fname, nseg=1) # Try move_to on a cell cell1.move_to(10.0, 20.0, 30.0) diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index 7600496e..2193bd31 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -22,7 +22,7 @@ def test_cell(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") - cell = SimpleMorphCell(fname, nseg=4) + cell = SimpleMorphCell(fname, nseg=1) # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) @@ -39,9 +39,9 @@ def test_cell(SimpleMorphCell): def test_network(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") - cell1 = SimpleMorphCell(fname, nseg=4) - cell2 = SimpleMorphCell(fname, nseg=4) - cell3 = SimpleMorphCell(fname, nseg=4) + cell1 = SimpleMorphCell(fname, nseg=1) + cell2 = SimpleMorphCell(fname, nseg=1) + cell3 = SimpleMorphCell(fname, nseg=1) net = jx.Network([cell1, cell2, cell3]) connect( @@ -122,7 +122,7 @@ def test_vis_networks_built_from_scartch(SimpleComp, SimpleBranch, SimpleCell): def test_mixed_network(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") - cell1 = SimpleMorphCell(fname, nseg=4) + cell1 = SimpleMorphCell(fname, nseg=1) comp = jx.Compartment() branch = jx.Branch(comp, 4) @@ -156,18 +156,18 @@ def test_mixed_network(SimpleMorphCell): _ = net.vis(detail="full") -def test_volume_plotting(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): +def test_volume_plotting( + SimpleComp, SimpleBranch, SimpleCell, SimpleNet, SimpleMorphCell +): comp = SimpleComp() - branch = SimpleBranch(4) - cell = SimpleCell(3, 4) - net = SimpleNet(2, 3, 4) + branch = SimpleBranch(2) + cell = SimpleCell(2, 2) + net = SimpleNet(2, 2, 2) for module in [comp, branch, cell, net]: module.compute_xyz() - morph_cell = jx.read_swc( - os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"), - nseg=1, - ) + fname = os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc") + morph_cell = SimpleMorphCell(fname, nseg=1) fig, ax = plt.subplots() for module in [comp, branch, cell, net, morph_cell]: @@ -180,7 +180,7 @@ def test_volume_plotting(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): plt.close() # test morph plotting (does not work if no radii in xyzr) - morph_cell.vis(type="morph") + morph_cell.branch(1).vis(type="morph") morph_cell.branch(1).vis( type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6} ) # plotting whole thing takes too long diff --git a/tests/test_set_ncomp.py b/tests/test_set_ncomp.py index eeed09f8..53dbb200 100644 --- a/tests/test_set_ncomp.py +++ b/tests/test_set_ncomp.py @@ -147,7 +147,7 @@ def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file, SimpleMorphCe fname = os.path.join(dirname, "swc_files", file) cell1 = SimpleMorphCell(fname, nseg=new_ncomp) - cell2 = SimpleMorphCell(fname, nseg=4) + cell2 = SimpleMorphCell(fname, nseg=1) for b in range(cell2.total_nbranches): cell2.branch(b).set_ncomp(new_ncomp) @@ -168,7 +168,7 @@ def test_simulation_accuracy_swc_init_vs_set_ncomp(new_ncomp, file, SimpleMorphC fname = os.path.join(dirname, "swc_files", file) cell1 = SimpleMorphCell(fname, nseg=new_ncomp) - cell2 = SimpleMorphCell(fname, nseg=4) + cell2 = SimpleMorphCell(fname, nseg=1) for b in range(cell2.total_nbranches): cell2.branch(b).set_ncomp(new_ncomp) From 0363ecf5e3ee2a639c423233eb9ddfe9b48a6d01 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 15 Nov 2024 13:23:33 +0100 Subject: [PATCH 06/15] fix: revert stim dur --- tests/test_api_equivalence.py | 4 ++-- tests/test_cell_matches_branch.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py index be0c1938..98d52d38 100644 --- a/tests/test_api_equivalence.py +++ b/tests/test_api_equivalence.py @@ -17,7 +17,7 @@ from jaxley.synapses import IonotropicSynapse -def test_api_equivalence_morphology(): +def test_api_equivalence_morphology(SimpleComp): """Test the API for how one can build morphologies from scratch.""" nseg_per_branch = 2 depth = 2 @@ -27,7 +27,7 @@ def test_api_equivalence_morphology(): parents = jnp.asarray(parents) num_branches = len(parents) - comp = jx.Compartment() + comp = SimpleComp() branch1 = jx.Branch([comp for _ in range(nseg_per_branch)]) cell1 = jx.Cell([branch1 for _ in range(num_branches)], parents=parents) diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index a2375513..dee00b2c 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -16,14 +16,14 @@ from jaxley.channels import HH -def _run_long_branch(dt, t_max, branch): +def _run_long_branch(dt, t_max, current, branch): branch.insert(HH()) branch.loc("all").make_trainable("radius", 1.0) params = branch.get_parameters() branch.loc(0.0).record() - branch.loc(0.0).stimulate(jx.step_current(0.2, 2.0, 0.1, dt, t_max)) + branch.loc(0.0).stimulate(current) def loss(params): s = jx.integrate(branch, params=params) @@ -35,14 +35,14 @@ def loss(params): return l, g -def _run_short_branches(dt, t_max, cell): +def _run_short_branches(dt, t_max, current, cell): cell.insert(HH()) cell.branch("all").loc("all").make_trainable("radius", 1.0) params = cell.get_parameters() cell.branch(0).loc(0.0).record() - cell.branch(0).loc(0.0).stimulate(jx.step_current(0.2, 2.0, 0.1, dt, t_max)) + cell.branch(0).loc(0.0).stimulate(current) def loss(params): s = jx.integrate(cell, params=params) @@ -58,9 +58,10 @@ def loss(params): def test_equivalence(SimpleBranch, SimpleCell): """Test whether a single long branch matches a cell of two shorter branches.""" dt = 0.025 - t_max = 2.0 # ms - l1, g1 = _run_long_branch(dt, t_max, SimpleBranch(8)) - l2, g2 = _run_short_branches(dt, t_max, SimpleCell(2, 4)) + t_max = 5.0 # ms + current = jx.step_current(0.5, 5.0, 0.1, dt, t_max) + l1, g1 = _run_long_branch(dt, t_max, current, SimpleBranch(8)) + l2, g2 = _run_short_branches(dt, t_max, current, SimpleCell(2, 4)) assert np.allclose(l1, l2), "Losses do not match." From 1a6fe183146831fc791993de41b9b40e8b635201 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 15 Nov 2024 15:26:41 +0100 Subject: [PATCH 07/15] fix: small fixes --- tests/conftest.py | 11 +++++------ tests/test_fixtures.py | 29 +++++++++++++++++++++++++++++ tests/test_set_ncomp.py | 10 +++++----- tests/test_viewing.py | 21 +++++++++++---------- 4 files changed, 50 insertions(+), 21 deletions(-) create mode 100644 tests/test_fixtures.py diff --git a/tests/conftest.py b/tests/conftest.py index 55f9b8b8..085f6d1a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,13 +4,10 @@ import os from copy import deepcopy -import jax.numpy as jnp -import numpy as np import pytest import jaxley as jx -from jaxley.channels import HH -from jaxley.synapses import IonotropicSynapse, TestSynapse +from jaxley.synapses import IonotropicSynapse @pytest.fixture(scope="session") @@ -30,7 +27,8 @@ def SimpleBranch(SimpleComp): def branch_w_shape(nseg, copy=True): if nseg not in branches: - branches[nseg] = jx.Branch([SimpleComp()] * nseg) + comp = SimpleComp() + branches[nseg] = jx.Branch([comp] * nseg) return deepcopy(branches[nseg]) if copy else branches[nseg] yield branch_w_shape @@ -49,7 +47,8 @@ def cell_w_shape(nbranches, nseg, copy=True): parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] depth += 1 parents = parents[:nbranches] - cells[key] = jx.Cell([SimpleBranch(nseg)] * nbranches, parents) + branch = SimpleBranch(nseg) + cells[key] = jx.Cell([branch] * nbranches, parents) return deepcopy(cells[key]) if copy else cells[key] yield cell_w_shape diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py new file mode 100644 index 00000000..b2d1971a --- /dev/null +++ b/tests/test_fixtures.py @@ -0,0 +1,29 @@ +import time +import warnings + + +def test_module_retrieval(SimpleNet): + t1 = time.time() + net = SimpleNet(2, 4, 4) + t2 = time.time() + net = SimpleNet(2, 4, 4) + t3 = time.time() + assert t2 - t1 > t3 - t2 + + +def test_direct_submodule_retrieval(SimpleBranch): + t1 = time.time() + branch = SimpleBranch(2, 4) + t2 = time.time() + branch = SimpleBranch(4, 4) + t3 = time.time() + assert t2 - t1 > t3 - t2 + + +def test_recursive_submodule_retrieval(SimpleNet): + t1 = time.time() + net = SimpleNet(3, 4, 4) + t2 = time.time() + net = SimpleNet(3, 4, 4) + t3 = time.time() + assert t2 - t1 > t3 - t2 diff --git a/tests/test_set_ncomp.py b/tests/test_set_ncomp.py index 53dbb200..8a9222ed 100644 --- a/tests/test_set_ncomp.py +++ b/tests/test_set_ncomp.py @@ -19,7 +19,7 @@ @pytest.mark.parametrize( "property", ["radius", "capacitance", "length", "axial_resistivity"] ) -def test_raise_for_heterogenous_modules(property, SimpleBranch): +def test_raise_for_heterogenous_modules(SimpleBranch, property): branch0 = SimpleBranch(4) branch1 = SimpleBranch(4) branch1.comp(1).set(property, 1.5) @@ -77,7 +77,7 @@ def test_raise_for_stimulus(SimpleCell): @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) def test_simulation_accuracy_api_equivalence_init_vs_setncomp_branch( - new_ncomp, SimpleBranch + SimpleBranch, new_ncomp ): """Test whether a module built from scratch matches module built with `set_ncomp()`. @@ -109,7 +109,7 @@ def test_simulation_accuracy_api_equivalence_init_vs_setncomp_branch( @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) def test_simulation_accuracy_api_equivalence_init_vs_setncomp_cell( - new_ncomp, SimpleBranch + SimpleBranch, new_ncomp ): """Test whether a module built from scratch matches module built with `set_ncomp()`.""" branch1 = SimpleBranch(new_ncomp) @@ -141,7 +141,7 @@ def test_simulation_accuracy_api_equivalence_init_vs_setncomp_cell( @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) @pytest.mark.parametrize("file", ["morph_250.swc"]) -def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file, SimpleMorphCell): +def test_api_equivalence_swc_lengths_and_radiuses(SimpleMorphCell, new_ncomp, file): """Test if the radiuses and lenghts of an SWC morph are reconstructed correctly.""" dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) @@ -162,7 +162,7 @@ def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file, SimpleMorphCe @pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) @pytest.mark.parametrize("file", ["morph_250.swc"]) -def test_simulation_accuracy_swc_init_vs_set_ncomp(new_ncomp, file, SimpleMorphCell): +def test_simulation_accuracy_swc_init_vs_set_ncomp(SimpleMorphCell, new_ncomp, file): """Test whether an SWC initially built with 4 ncomp works after `set_ncomp()`.""" dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) diff --git a/tests/test_viewing.py b/tests/test_viewing.py index ce7a9440..547120e7 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -292,22 +292,22 @@ def test_view_attrs(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): """Check if different ways to index into Modules/Views work correctly.""" # test int, range, slice, list, np.array, pd.Index - index_types = [ - 0, - range(3), - slice(0, 3), - [0, 1, 2], - np.array([0, 1, 2]), - pd.Index([0, 1, 2]), - np.array([True, False, True, False] * 100)[: len(module.nodes)], - ] for module in [ SimpleComp(), SimpleBranch(4), SimpleCell(3, 4), - SimpleNet(2, 3, 4, connect=True), + SimpleNet(2, 3, 4), ]: + index_types = [ + 0, + range(3), + slice(0, 3), + [0, 1, 2], + np.array([0, 1, 2]), + pd.Index([0, 1, 2]), + np.array([True, False, True, False] * 100)[: len(module.nodes)], + ] # comp.comp is not allowed all_inds = module.nodes.index.to_numpy() @@ -333,6 +333,7 @@ def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, Simple module.comp(0) if isinstance(module, jx.Network): + connect(module[0, 0, :], module[1, 0, :], TestSynapse()) all_inds = module.edges.index.to_numpy() for index in index_types[:-1] + [np.array([True, False, True, False])]: expected_inds = all_inds[index] From 1094b0e360840f55b79b8c7ce0ff709c528b00d2 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 15 Nov 2024 17:11:29 +0100 Subject: [PATCH 08/15] fix: add test_fixtures test and verify that it works, although it does not seem to speed up tests significantly --- tests/conftest.py | 58 +++++++++++++++++++++-------------- tests/test_fixtures.py | 69 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 97 insertions(+), 30 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 085f6d1a..a578700f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ # licensed under the Apache License Version 2.0, see import os +import warnings from copy import deepcopy import pytest @@ -12,26 +13,28 @@ @pytest.fixture(scope="session") def SimpleComp(): - comp = jx.Compartment() + comps = {} - def get_comp(copy=True): - return deepcopy(comp) if copy else comp + def get_or_build_comp(copy=True, force_init=False): + if "comp" not in comps or force_init: + comps["comp"] = jx.Compartment() + return deepcopy(comps["comp"]) if copy else comps["comp"] - yield get_comp - del comp + yield get_or_build_comp + comps = {} @pytest.fixture(scope="session") def SimpleBranch(SimpleComp): branches = {} - def branch_w_shape(nseg, copy=True): - if nseg not in branches: - comp = SimpleComp() + def get_or_build_branch(nseg, copy=True, force_init=False): + if nseg not in branches or force_init: + comp = SimpleComp(force_init=force_init) branches[nseg] = jx.Branch([comp] * nseg) return deepcopy(branches[nseg]) if copy else branches[nseg] - yield branch_w_shape + yield get_or_build_branch branches = {} @@ -39,19 +42,19 @@ def branch_w_shape(nseg, copy=True): def SimpleCell(SimpleBranch): cells = {} - def cell_w_shape(nbranches, nseg, copy=True): - if key := (nbranches, nseg) not in cells: + def get_or_build_cell(nbranches, nseg, copy=True, force_init=False): + if key := (nbranches, nseg) not in cells or force_init: parents = [-1] depth = 0 while nbranches > len(parents): parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] depth += 1 parents = parents[:nbranches] - branch = SimpleBranch(nseg) + branch = SimpleBranch(nseg=nseg, force_init=force_init) cells[key] = jx.Cell([branch] * nbranches, parents) return deepcopy(cells[key]) if copy else cells[key] - yield cell_w_shape + yield get_or_build_cell cells = {} @@ -59,15 +62,20 @@ def cell_w_shape(nbranches, nseg, copy=True): def SimpleNet(SimpleCell): nets = {} - def net_w_shape(ncells, nbranches, nseg, connect=False, copy=True): - if key := (ncells, nbranches, nseg, connect) not in nets: - net = jx.Network([SimpleCell(nbranches, nseg)] * ncells) + def get_or_build_net( + ncells, nbranches, nseg, connect=False, copy=True, force_init=False + ): + if key := (ncells, nbranches, nseg, connect) not in nets or force_init: + net = jx.Network( + [SimpleCell(nbranches=nbranches, nseg=nseg, force_init=force_init)] + * ncells + ) if connect: jx.connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse()) nets[key] = net return deepcopy(nets[key]) if copy else nets[key] - yield net_w_shape + yield get_or_build_net nets = {} @@ -78,13 +86,15 @@ def SimpleMorphCell(): cells = {} - def cell_w_params(fname=None, nseg=1, max_branch_len=2_000.0, copy=True): + def get_or_build_cell( + fname=None, nseg=1, max_branch_len=2_000.0, copy=True, force_init=False + ): fname = default_fname if fname is None else fname - if key := (fname, nseg, max_branch_len) not in cells: + if key := (fname, nseg, max_branch_len) not in cells or force_init: cells[key] = jx.read_swc(fname, nseg, max_branch_len, assign_groups=True) return deepcopy(cells[key]) if copy else cells[key] - yield cell_w_params + yield get_or_build_cell cells = {} @@ -95,11 +105,13 @@ def swc2jaxley(): params = {} - def swc2jaxley_params(fname=None, max_branch_len=2_000.0, sort=True): + def get_or_compute_swc2jaxley_params( + fname=None, max_branch_len=2_000.0, sort=True, force_init=False + ): fname = default_fname if fname is None else fname - if key := (fname, max_branch_len, sort) not in params: + if key := (fname, max_branch_len, sort) not in params or force_init: params[key] = jx.utils.swc.swc_to_jaxley(fname, max_branch_len, sort) return params[key] - yield swc2jaxley_params + yield get_or_compute_swc2jaxley_params params = {} diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index b2d1971a..e91866e5 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -1,29 +1,84 @@ import time import warnings +import jaxley as jx + def test_module_retrieval(SimpleNet): + t0 = time.time() + comp = jx.Compartment() + branch = jx.Branch([comp] * 4) + cell = jx.Cell([branch] * 4, [-1, 0, 0, 1]) + net = jx.Network([cell] * 2) t1 = time.time() + net = SimpleNet(2, 4, 4) t2 = time.time() + + assert ((t2 - t1) - (t1 - t0)) / ( + t1 - t0 + ) < 0.1, f"Fixture is slower than manual init." + net = SimpleNet(2, 4, 4) t3 = time.time() - assert t2 - t1 > t3 - t2 + assert ( + t1 - t0 > t2 - t1 > t3 - t2 + ), f"T_get: from pre-existing fixture {t3 - t2}, from fixture: {(t2 - t1)}, manual: {(t1 - t0)}" def test_direct_submodule_retrieval(SimpleBranch): t1 = time.time() - branch = SimpleBranch(2, 4) + branch = SimpleBranch(2, 3) t2 = time.time() - branch = SimpleBranch(4, 4) + branch = SimpleBranch(4, 3) t3 = time.time() - assert t2 - t1 > t3 - t2 + assert ( + t2 - t1 > t3 - t2 + ), f"T_get: from pre-existing fixture {t3 - t2}, from fixture: {(t2 - t1)}" def test_recursive_submodule_retrieval(SimpleNet): t1 = time.time() - net = SimpleNet(3, 4, 4) + net = SimpleNet(3, 4, 3) + t2 = time.time() + net = SimpleNet(3, 4, 3) + t3 = time.time() + assert ( + t2 - t1 > t3 - t2 + ), f"T_get: from pre-existing fixture {t3 - t2}, from fixture: {(t2 - t1)}" + + +def test_module_reinit(SimpleComp): + t0 = time.time() + comp = jx.Compartment() + t1 = time.time() + + comp = SimpleComp() + + t2 = time.time() + comp = SimpleComp() + t3 = time.time() + net = SimpleComp(force_init=True) + t4 = time.time() + + msg = f"T_get: reinit {t4 - t3}, from fixture: {(t3 - t2)}, manual: {(t1 - t0)}" + assert t1 - t0 > t4 - t3 or abs(((t1 - t0) - (t4 - t3)) / (t1 - t0)) < 0.3, msg + assert t4 - t3 > t3 - t2, msg + + +def test_module_reinit2(SimpleComp): + t0 = time.time() + comp = jx.Compartment() + t1 = time.time() + + comp = SimpleComp() + t2 = time.time() - net = SimpleNet(3, 4, 4) + comp = SimpleComp() t3 = time.time() - assert t2 - t1 > t3 - t2 + net = SimpleComp(force_init=True) + t4 = time.time() + + msg = f"T_get: reinit {t4 - t3}, from fixture: {(t3 - t2)}, manual: {(t1 - t0)}" + assert t1 - t0 > t4 - t3 or abs(((t1 - t0) - (t4 - t3)) / (t1 - t0)) < 0.3, msg + assert t4 - t3 > t3 - t2, msg From fd4545e04bd81c1ef10425d15e02334cc818749b Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 15 Nov 2024 21:58:30 +0100 Subject: [PATCH 09/15] fix: small fixes --- tests/conftest.py | 11 +++++------ tests/test_fixtures.py | 41 +++++++++++++++-------------------------- 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a578700f..bbf95884 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ # licensed under the Apache License Version 2.0, see import os -import warnings from copy import deepcopy import pytest @@ -18,7 +17,7 @@ def SimpleComp(): def get_or_build_comp(copy=True, force_init=False): if "comp" not in comps or force_init: comps["comp"] = jx.Compartment() - return deepcopy(comps["comp"]) if copy else comps["comp"] + return deepcopy(comps["comp"]) if copy and not force_init else comps["comp"] yield get_or_build_comp comps = {} @@ -32,7 +31,7 @@ def get_or_build_branch(nseg, copy=True, force_init=False): if nseg not in branches or force_init: comp = SimpleComp(force_init=force_init) branches[nseg] = jx.Branch([comp] * nseg) - return deepcopy(branches[nseg]) if copy else branches[nseg] + return deepcopy(branches[nseg]) if copy and not force_init else branches[nseg] yield get_or_build_branch branches = {} @@ -52,7 +51,7 @@ def get_or_build_cell(nbranches, nseg, copy=True, force_init=False): parents = parents[:nbranches] branch = SimpleBranch(nseg=nseg, force_init=force_init) cells[key] = jx.Cell([branch] * nbranches, parents) - return deepcopy(cells[key]) if copy else cells[key] + return deepcopy(cells[key]) if copy and not force_init else cells[key] yield get_or_build_cell cells = {} @@ -73,7 +72,7 @@ def get_or_build_net( if connect: jx.connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse()) nets[key] = net - return deepcopy(nets[key]) if copy else nets[key] + return deepcopy(nets[key]) if copy and not force_init else nets[key] yield get_or_build_net nets = {} @@ -92,7 +91,7 @@ def get_or_build_cell( fname = default_fname if fname is None else fname if key := (fname, nseg, max_branch_len) not in cells or force_init: cells[key] = jx.read_swc(fname, nseg, max_branch_len, assign_groups=True) - return deepcopy(cells[key]) if copy else cells[key] + return deepcopy(cells[key]) if copy and not force_init else cells[key] yield get_or_build_cell cells = {} diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index e91866e5..590451e9 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -1,8 +1,15 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + import time import warnings +import pytest + import jaxley as jx +pytest.skip(allow_module_level=True) + def test_module_retrieval(SimpleNet): t0 = time.time() @@ -12,14 +19,14 @@ def test_module_retrieval(SimpleNet): net = jx.Network([cell] * 2) t1 = time.time() - net = SimpleNet(2, 4, 4) + net = SimpleNet(2, 4, 4, force_init=False) t2 = time.time() assert ((t2 - t1) - (t1 - t0)) / ( t1 - t0 ) < 0.1, f"Fixture is slower than manual init." - net = SimpleNet(2, 4, 4) + net = SimpleNet(2, 4, 4, force_init=False) t3 = time.time() assert ( t1 - t0 > t2 - t1 > t3 - t2 @@ -28,9 +35,9 @@ def test_module_retrieval(SimpleNet): def test_direct_submodule_retrieval(SimpleBranch): t1 = time.time() - branch = SimpleBranch(2, 3) + branch = SimpleBranch(2, 3, force_init=False) t2 = time.time() - branch = SimpleBranch(4, 3) + branch = SimpleBranch(4, 3, force_init=False) t3 = time.time() assert ( t2 - t1 > t3 - t2 @@ -39,9 +46,9 @@ def test_direct_submodule_retrieval(SimpleBranch): def test_recursive_submodule_retrieval(SimpleNet): t1 = time.time() - net = SimpleNet(3, 4, 3) + net = SimpleNet(3, 4, 3, force_init=False) t2 = time.time() - net = SimpleNet(3, 4, 3) + net = SimpleNet(3, 4, 3, force_init=False) t3 = time.time() assert ( t2 - t1 > t3 - t2 @@ -53,28 +60,10 @@ def test_module_reinit(SimpleComp): comp = jx.Compartment() t1 = time.time() - comp = SimpleComp() - - t2 = time.time() - comp = SimpleComp() - t3 = time.time() - net = SimpleComp(force_init=True) - t4 = time.time() - - msg = f"T_get: reinit {t4 - t3}, from fixture: {(t3 - t2)}, manual: {(t1 - t0)}" - assert t1 - t0 > t4 - t3 or abs(((t1 - t0) - (t4 - t3)) / (t1 - t0)) < 0.3, msg - assert t4 - t3 > t3 - t2, msg - - -def test_module_reinit2(SimpleComp): - t0 = time.time() - comp = jx.Compartment() - t1 = time.time() - - comp = SimpleComp() + comp = SimpleComp(force_init=False) t2 = time.time() - comp = SimpleComp() + comp = SimpleComp(force_init=False) t3 = time.time() net = SimpleComp(force_init=True) t4 = time.time() From 94e7903528c36b855501f21614bc12b9f4dcfc00 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 15 Nov 2024 22:10:51 +0100 Subject: [PATCH 10/15] doc: add docstrings to fixtures --- tests/conftest.py | 109 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 97 insertions(+), 12 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bbf95884..89621245 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,9 +12,20 @@ @pytest.fixture(scope="session") def SimpleComp(): + """Fixture for creating or retrieving an already created compartment.""" comps = {} - def get_or_build_comp(copy=True, force_init=False): + def get_or_build_comp( + copy: bool = True, force_init: bool = False + ) -> jx.Compartment: + """Create or retrieve a compartment. + + Args: + copy: Whether to return a copy of the compartment. Default is True. + force_init: Force the init from scratch. Default is False. + + Returns: + jx.Compartment().""" if "comp" not in comps or force_init: comps["comp"] = jx.Compartment() return deepcopy(comps["comp"]) if copy and not force_init else comps["comp"] @@ -25,9 +36,23 @@ def get_or_build_comp(copy=True, force_init=False): @pytest.fixture(scope="session") def SimpleBranch(SimpleComp): + """Fixture for creating or retrieving an already created branch.""" branches = {} - def get_or_build_branch(nseg, copy=True, force_init=False): + def get_or_build_branch( + nseg: int, copy: bool = True, force_init: bool = False + ) -> jx.Branch: + """Create or retrieve a branch. + + If a branch with the same number of compartments already exists, it is returned. + + Args: + nseg: Number of compartments in the branch. + copy: Whether to return a copy of the branch. Default is True. + force_init: Force the init from scratch. Default is False. + + Returns: + jx.Branch().""" if nseg not in branches or force_init: comp = SimpleComp(force_init=force_init) branches[nseg] = jx.Branch([comp] * nseg) @@ -39,9 +64,25 @@ def get_or_build_branch(nseg, copy=True, force_init=False): @pytest.fixture(scope="session") def SimpleCell(SimpleBranch): + """Fixture for creating or retrieving an already created cell.""" cells = {} - def get_or_build_cell(nbranches, nseg, copy=True, force_init=False): + def get_or_build_cell( + nbranches: int, nseg: int, copy: bool = True, force_init: bool = False + ) -> jx.Cell: + """Create or retrieve a cell. + + If a cell with the same number of branches and compartments already exists, it + is returned. The branch strcuture is assumed as [-1, 0, 0, 1, 1, 2, 2, ...]. + + Args: + nbranches: Number of branches in the cell. + nseg: Number of compartments in each branch. + copy: Whether to return a copy of the cell. Default is True. + force_init: Force the init from scratch. Default is False. + + Returns: + jx.Cell().""" if key := (nbranches, nseg) not in cells or force_init: parents = [-1] depth = 0 @@ -59,11 +100,32 @@ def get_or_build_cell(nbranches, nseg, copy=True, force_init=False): @pytest.fixture(scope="session") def SimpleNet(SimpleCell): + """Fixture for creating or retrieving an already created network.""" nets = {} def get_or_build_net( - ncells, nbranches, nseg, connect=False, copy=True, force_init=False - ): + ncells: int, + nbranches: int, + nseg: int, + connect: bool = False, + copy: bool = True, + force_init: bool = False, + ) -> jx.Network: + """Create or retrieve a network. + + If a network with the same number of cells, branches, compartments, and + connections already exists, it is returned. + + Args: + ncells: Number of cells in the network. + nbranches: Number of branches in each cell. + nseg: Number of compartments in each branch. + connect: Whether to connect the first two cells in the network. + copy: Whether to return a copy of the network. Default is True. + force_init: Force the init from scratch. Default is False. + + Returns: + jx.Network().""" if key := (ncells, nbranches, nseg, connect) not in nets or force_init: net = jx.Network( [SimpleCell(nbranches=nbranches, nseg=nseg, force_init=force_init)] @@ -80,14 +142,33 @@ def get_or_build_net( @pytest.fixture(scope="session") def SimpleMorphCell(): - dirname = os.path.dirname(__file__) - default_fname = os.path.join(dirname, "swc_files", "morph.swc") # n120 + """Fixture for creating or retrieving an already created morpholgy.""" cells = {} def get_or_build_cell( - fname=None, nseg=1, max_branch_len=2_000.0, copy=True, force_init=False - ): + fname: str = None, + nseg: int = 1, + max_branch_len: float = 2_000.0, + copy: bool = True, + force_init: bool = False, + ) -> jx.Cell: + """Create or retrieve a cell from an SWC file. + + If a cell with the same SWC file, number of compartments, and maximum branch + length already exists, it is returned. + + Args: + fname: Path to the SWC file. + nseg: Number of compartments in each branch. + max_branch_len: Maximum length of a branch. + copy: Whether to return a copy of the cell. Default is True. + force_init: Force the init from scratch. Default is False. + + Returns: + jx.Cell().""" + dirname = os.path.dirname(__file__) + default_fname = os.path.join(dirname, "swc_files", "morph.swc") fname = default_fname if fname is None else fname if key := (fname, nseg, max_branch_len) not in cells or force_init: cells[key] = jx.read_swc(fname, nseg, max_branch_len, assign_groups=True) @@ -99,14 +180,18 @@ def get_or_build_cell( @pytest.fixture(scope="session") def swc2jaxley(): - dirname = os.path.dirname(__file__) - default_fname = os.path.join(dirname, "swc_files", "morph.swc") # n120 + """Fixture for creating or retrieving an already computed params of a morphology.""" params = {} def get_or_compute_swc2jaxley_params( - fname=None, max_branch_len=2_000.0, sort=True, force_init=False + fname: str = None, + max_branch_len: float = 2_000.0, + sort: bool = True, + force_init: bool = False, ): + dirname = os.path.dirname(__file__) + default_fname = os.path.join(dirname, "swc_files", "morph.swc") fname = default_fname if fname is None else fname if key := (fname, max_branch_len, sort) not in params or force_init: params[key] = jx.utils.swc.swc_to_jaxley(fname, max_branch_len, sort) From 1304e17a97cbd24936ccbfae33a449ddce2ece26 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sat, 16 Nov 2024 01:06:17 +0100 Subject: [PATCH 11/15] fix: reduce number of different stimuli and shorten where possible --- tests/conftest.py | 16 +++++++++ tests/jaxley_identical/test_basic_modules.py | 34 ++++++++++++------- tests/jaxley_identical/test_grad.py | 4 ++- .../test_radius_and_length.py | 20 ++++++----- tests/jaxley_identical/test_swc.py | 10 +++--- tests/jaxley_vs_neuron/test_branch.py | 6 ++-- tests/jaxley_vs_neuron/test_cell.py | 3 +- tests/jaxley_vs_neuron/test_comp.py | 3 +- tests/test_api_equivalence.py | 29 +++++++++++----- tests/test_cell_matches_branch.py | 4 ++- tests/test_channels.py | 8 +++-- tests/test_composability_of_modules.py | 15 ++++---- tests/test_data_feeding.py | 18 +++++++--- tests/test_grad.py | 8 +++-- tests/test_make_trainable.py | 16 ++++++--- tests/test_optimize.py | 10 ++++-- tests/test_record_and_stimulate.py | 12 +++++-- tests/test_shared_state.py | 4 ++- tests/test_solver.py | 4 ++- 19 files changed, 159 insertions(+), 65 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 89621245..d6014356 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,22 @@ from jaxley.synapses import IonotropicSynapse +@pytest.fixture(scope="session") +def step_current(): + def get_step_current( + i_delay: float = 0.5, + i_dur: float = 1.0, + i_amp: float = 0.1, + dt: float = 0.025, + t_max: float = 3.0, + i_offset: float = 0.0, + ): + """Create a step current stimulus.""" + return jx.step_current(i_delay, i_dur, i_amp, dt, t_max, i_offset) + + yield get_step_current + + @pytest.fixture(scope="session") def SimpleComp(): """Fixture for creating or retrieving an already created compartment.""" diff --git a/tests/jaxley_identical/test_basic_modules.py b/tests/jaxley_identical/test_basic_modules.py index 577878c2..4b46a5e8 100644 --- a/tests/jaxley_identical/test_basic_modules.py +++ b/tests/jaxley_identical/test_basic_modules.py @@ -25,8 +25,9 @@ @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_compartment(voltage_solver, SimpleComp, SimpleBranch, SimpleCell, SimpleNet): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) tolerance = 1e-8 voltages_081123 = jnp.asarray( @@ -87,8 +88,9 @@ def test_compartment(voltage_solver, SimpleComp, SimpleBranch, SimpleCell, Simpl @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_branch(voltage_solver, SimpleBranch): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) branch = SimpleBranch(2) branch.insert(HH()) @@ -121,8 +123,9 @@ def test_branch(voltage_solver, SimpleBranch): def test_branch_fwd_euler_uneven_radiuses(SimpleBranch): dt = 0.025 # ms - t_max = 10.0 # ms - current = jx.step_current(0.5, 1.0, 2.0, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=2.0, delta_t=0.025, t_max=10.0 + ) branch = SimpleBranch(8) branch.set("axial_resistivity", 500.0) @@ -159,8 +162,9 @@ def test_branch_fwd_euler_uneven_radiuses(SimpleBranch): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_cell(voltage_solver, SimpleCell): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) cell = SimpleCell(3, 2) cell.insert(HH()) @@ -194,8 +198,9 @@ def test_cell(voltage_solver, SimpleCell): def test_cell_unequal_compartment_number(SimpleBranch): """Tests a cell where every branch has a different number of compartments.""" dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.1, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + ) branch1 = SimpleBranch(nseg=1) branch2 = SimpleBranch(nseg=2) @@ -227,8 +232,9 @@ def test_cell_unequal_compartment_number(SimpleBranch): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_net(voltage_solver, SimpleNet): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) net = SimpleNet(2, 3, 2) @@ -315,7 +321,9 @@ def test_complex_net(voltage_solver, SimpleNet): "TestSynapse_gC", 0.24 / point_process_to_dist_factor ) - current = jx.step_current(0.5, 0.5, 0.1, 0.025, 10.0) + current = jx.step_current( + i_delay=0.5, i_dur=0.5, i_amp=0.1, delta_t=0.025, t_max=10.0 + ) for i in range(3): net.cell(i).branch(0).loc(0.0).stimulate(current) diff --git a/tests/jaxley_identical/test_grad.py b/tests/jaxley_identical/test_grad.py index b6c98d30..bfd1de84 100644 --- a/tests/jaxley_identical/test_grad.py +++ b/tests/jaxley_identical/test_grad.py @@ -51,7 +51,9 @@ def test_network_grad(SimpleNet): "TestSynapse_gC", 0.24 / point_process_to_dist_factor ) - current = jx.step_current(0.5, 0.5, 0.1, 0.025, 10.0) + current = jx.step_current( + i_delay=0.5, i_dur=0.5, i_amp=0.1, delta_t=0.025, t_max=10.0 + ) for i in range(3): net.cell(i).branch(0).loc(0.0).stimulate(current) diff --git a/tests/jaxley_identical/test_radius_and_length.py b/tests/jaxley_identical/test_radius_and_length.py index f0ef3c83..cd19b020 100644 --- a/tests/jaxley_identical/test_radius_and_length.py +++ b/tests/jaxley_identical/test_radius_and_length.py @@ -25,8 +25,9 @@ @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_radius_and_length_compartment(voltage_solver, SimpleComp): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) comp = SimpleComp() @@ -65,8 +66,9 @@ def test_radius_and_length_compartment(voltage_solver, SimpleComp): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_radius_and_length_branch(voltage_solver, SimpleBranch): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) branch = SimpleBranch(nseg=2) @@ -105,8 +107,9 @@ def test_radius_and_length_branch(voltage_solver, SimpleBranch): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_radius_and_length_cell(voltage_solver, SimpleCell): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) num_branches = 3 cell = SimpleCell(num_branches, nseg=2) @@ -149,8 +152,9 @@ def test_radius_and_length_cell(voltage_solver, SimpleCell): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_radius_and_length_net(voltage_solver, SimpleNet): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.02, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=5.0 + ) num_branches = 3 net = SimpleNet(2, num_branches, 2) diff --git a/tests/jaxley_identical/test_swc.py b/tests/jaxley_identical/test_swc.py index eccdfd19..ea15cf94 100644 --- a/tests/jaxley_identical/test_swc.py +++ b/tests/jaxley_identical/test_swc.py @@ -26,8 +26,9 @@ @pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"]) def test_swc_cell(voltage_solver: str, file: str, SimpleMorphCell): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.2, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.2, delta_t=0.025, t_max=5.0 + ) dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "../swc_files", file) @@ -86,8 +87,9 @@ def test_swc_cell(voltage_solver: str, file: str, SimpleMorphCell): @pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"]) def test_swc_net(voltage_solver: str, SimpleMorphCell): dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.2, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.2, delta_t=0.025, t_max=5.0 + ) dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "../swc_files/morph.swc") diff --git a/tests/jaxley_vs_neuron/test_branch.py b/tests/jaxley_vs_neuron/test_branch.py index d818cc58..fa4022ef 100644 --- a/tests/jaxley_vs_neuron/test_branch.py +++ b/tests/jaxley_vs_neuron/test_branch.py @@ -64,7 +64,8 @@ def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max, solver): branch.set("HH_n", 0.3644787002343737) branch.set("v", -62.0) - branch.loc(0.0).stimulate(jx.step_current(i_delay, i_dur, i_amp, dt, t_max)) + current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max) + branch.loc(0.0).stimulate(current) branch.loc(0.0).record() branch.loc(1.0).record() @@ -207,7 +208,8 @@ def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams, capacitances, solve counter += 1 # 0.02 is fine here because nseg=8 for NEURON, but nseg=16 for jaxley. - branch.loc(0.02).stimulate(jx.step_current(i_delay, i_dur, i_amp, dt, t_max)) + current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max) + branch.loc(0.02).stimulate(current) branch.loc(0.02).record() branch.loc(0.52).record() branch.loc(0.98).record() diff --git a/tests/jaxley_vs_neuron/test_cell.py b/tests/jaxley_vs_neuron/test_cell.py index 06260a43..22c8d6ee 100644 --- a/tests/jaxley_vs_neuron/test_cell.py +++ b/tests/jaxley_vs_neuron/test_cell.py @@ -59,7 +59,8 @@ def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max, solver): cell.set("HH_n", 0.3644787002343737) cell.set("v", -62.0) - cell.branch(0).loc(0.0).stimulate(jx.step_current(i_delay, i_dur, i_amp, dt, t_max)) + current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max) + cell.branch(0).loc(0.0).stimulate(current) cell.branch(0).loc(0.0).record() cell.branch(1).loc(1.0).record() cell.branch(2).loc(1.0).record() diff --git a/tests/jaxley_vs_neuron/test_comp.py b/tests/jaxley_vs_neuron/test_comp.py index d895af93..939bdad1 100644 --- a/tests/jaxley_vs_neuron/test_comp.py +++ b/tests/jaxley_vs_neuron/test_comp.py @@ -53,7 +53,8 @@ def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max): comp.set("v", -62.0) comp.set("capacitance", 5.0) - comp.stimulate(jx.step_current(i_delay, i_dur, i_amp, dt, t_max)) + current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max) + comp.stimulate(current) comp.record() voltages = jx.integrate(comp, delta_t=dt) diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py index 98d52d38..2146a870 100644 --- a/tests/test_api_equivalence.py +++ b/tests/test_api_equivalence.py @@ -38,7 +38,9 @@ def test_api_equivalence_morphology(SimpleComp): cell1.branch(2).loc(0.4).record() cell2.branch(2).loc(0.4).record() - current = jx.step_current(0.5, 1.0, 1.0, dt, 3.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) cell1.branch(1).loc(1.0).stimulate(current) cell2.branch(1).loc(1.0).stimulate(current) @@ -53,7 +55,9 @@ def test_solver_backends_comp(SimpleComp): """Test whether ways of adding synapses are equivalent.""" comp = SimpleComp() - current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) comp.stimulate(current) comp.record() @@ -69,7 +73,9 @@ def test_solver_backends_branch(SimpleBranch): """Test whether ways of adding synapses are equivalent.""" branch = SimpleBranch(4) - current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) branch.loc(0.0).stimulate(current) branch.loc(0.5).record() @@ -86,7 +92,9 @@ def test_solver_backends_cell(SimpleCell): """Test whether ways of adding synapses are equivalent.""" cell = SimpleCell(4, 4) - current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) cell.branch(0).loc(0.0).stimulate(current) cell.branch(0).loc(0.5).record() cell.branch(3).loc(0.5).record() @@ -114,7 +122,9 @@ def test_solver_backends_net(SimpleNet): IonotropicSynapse(), ) - current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() net.cell(1).branch(3).loc(0.5).record() @@ -152,7 +162,9 @@ def test_api_equivalence_synapses(SimpleNet): connect(pre, post, IonotropicSynapse()) for net in [net1, net2]: - current = jx.step_current(0.5, 1.0, 0.5, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() net.cell(1).branch(3).loc(0.5).record() @@ -183,8 +195,9 @@ def test_api_equivalence_network_matches_cell(SimpleBranch): This runs an unequal number of compartments per branch.""" dt = 0.025 # ms - t_max = 5.0 # ms - current = jx.step_current(0.5, 1.0, 0.1, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) branch1 = SimpleBranch(nseg=1) branch2 = SimpleBranch(nseg=2) diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index dee00b2c..18d06555 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -59,7 +59,9 @@ def test_equivalence(SimpleBranch, SimpleCell): """Test whether a single long branch matches a cell of two shorter branches.""" dt = 0.025 t_max = 5.0 # ms - current = jx.step_current(0.5, 5.0, 0.1, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) l1, g1 = _run_long_branch(dt, t_max, current, SimpleBranch(8)) l2, g2 = _run_short_branches(dt, t_max, current, SimpleCell(2, 4)) diff --git a/tests/test_channels.py b/tests/test_channels.py index e4722334..f9c4cf23 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -273,7 +273,9 @@ def test_init_states_complex_channel(SimpleCell): cell.init_states() - current = jx.step_current(1.0, 1.0, 0.1, 0.025, 3.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) cell.branch(2).comp(0).stimulate(current) cell.branch(2).comp(0).record() voltages = jx.integrate(cell) @@ -332,7 +334,9 @@ def compute_current(self, states, v, params): dt = 0.025 # ms t_max = 10.0 # ms cell = SimpleCell(1, 1) - cell.branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, dt, t_max)) + cell.branch(0).loc(0.0).stimulate( + jx.step_current(i_delay=0.5, i_dur=0.5, i_amp=0.1, delta_t=0.025, t_max=10.0) + ) cell.insert(User()) cell.insert(Dummy1()) diff --git a/tests/test_composability_of_modules.py b/tests/test_composability_of_modules.py index f8d2cbbb..3ad5a494 100644 --- a/tests/test_composability_of_modules.py +++ b/tests/test_composability_of_modules.py @@ -15,8 +15,9 @@ def test_compose_branch(): """Test inserting to comp and composing to branch equals inserting to branch.""" dt = 0.025 - t_max = 3.0 - current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) comp1 = jx.Compartment() comp1.insert(HH()) @@ -41,8 +42,9 @@ def test_compose_cell(): """Test inserting to branch and composing to cell equals inserting to cell.""" nseg_per_branch = 4 dt = 0.025 - t_max = 3.0 - current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) comp = jx.Compartment() @@ -69,8 +71,9 @@ def test_compose_net(): """Test inserting to cell and composing to net equals inserting to net.""" nseg_per_branch = 4 dt = 0.025 - t_max = 3.0 - current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) comp = jx.Compartment() branch = jx.Branch(comp, nseg_per_branch) diff --git a/tests/test_data_feeding.py b/tests/test_data_feeding.py index 744c5720..93f6f4c9 100644 --- a/tests/test_data_feeding.py +++ b/tests/test_data_feeding.py @@ -24,11 +24,15 @@ def test_constant_and_data_stimulus(SimpleCell): i_amp_const = 0.02 i_amps_data = jnp.asarray([0.01, 0.005]) - current = jx.step_current(1.0, 1.0, i_amp_const, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=3.0 + ) cell.branch(1).loc(0.6).stimulate(current) def provide_data(i_amps): - current = jx.datapoint_to_step_currents(1.0, 1.0, i_amps, 0.025, 5.0) + current = jx.datapoint_to_step_currents( + i_delay=0.5, i_dur=1.0, i_amp=i_amps, delta_t=0.025, t_max=3.0 + ) data_stimuli = None data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[0], data_stimuli) data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[1], data_stimuli) @@ -43,7 +47,9 @@ def simulate(i_amps): cell.delete_stimuli() i_amp_summed = i_amp_const + jnp.sum(i_amps_data) - current_sum = jx.step_current(1.0, 1.0, i_amp_summed, 0.025, 5.0) + current_sum = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=3.0 + ) cell.branch(1).loc(0.6).stimulate(current_sum) v_stim = jx.integrate(cell) @@ -59,7 +65,7 @@ def test_data_vs_constant_stimulus(SimpleCell): i_amps_data = jnp.asarray([0.01, 0.005]) def provide_data(i_amps): - current = jx.datapoint_to_step_currents(1.0, 1.0, i_amps, 0.025, 5.0) + current = jx.datapoint_to_step_currents(0.5, 1.0, i_amps, 0.025, 3.0) data_stimuli = None data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[0], data_stimuli) data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[1], data_stimuli) @@ -74,7 +80,9 @@ def simulate(i_amps): cell.delete_stimuli() i_amp_summed = jnp.sum(i_amps_data) - current_sum = jx.step_current(1.0, 1.0, i_amp_summed, 0.025, 5.0) + current_sum = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=3.0 + ) cell.branch(1).loc(0.6).stimulate(current_sum) v_stim = jx.integrate(cell) diff --git a/tests/test_grad.py b/tests/test_grad.py index f719d570..9c3df885 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -25,7 +25,9 @@ def simulate_with_params(params): comp = SimpleComp(copy=True) comp.insert(HH()) comp.record() - comp.stimulate(jx.step_current(0.1, 0.2, 0.1, 0.025, 5.0)) + comp.stimulate( + jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0) + ) val = 0.2 if key == "HH_m" else -70.0 step_size = 0.01 @@ -59,7 +61,9 @@ def simulate_with_params(params): branch = SimpleBranch(4) branch.loc(0.0).record() - branch.loc(0.0).stimulate(jx.step_current(0.1, 0.2, 0.1, 0.025, 5.0)) + branch.loc(0.0).stimulate( + jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0) + ) branch.loc(0.0).insert(HH()) val = 0.2 if key == "HH_m" else -70.0 diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 1a84bc11..a3122aed 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -321,7 +321,9 @@ def test_make_trainable_corresponds_to_set_pospischil(SimpleNet): net1.cell(1).branch(1).loc(0.0).record() net2.cell(1).branch(1).loc(0.0).record() - current = jx.step_current(2.0, 3.0, 0.2, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) net1.cell(0).branch(1).loc(0.0).stimulate(current) net2.cell(0).branch(1).loc(0.0).stimulate(current) voltages1 = jx.integrate(net1, params=params1) @@ -386,7 +388,9 @@ def test_data_set_vs_make_trainable_pospischil(SimpleNet): net1.cell(1).branch(1).loc(0.0).record() net2.cell(1).branch(1).loc(0.0).record() - current = jx.step_current(2.0, 3.0, 0.2, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + ) net1.cell(0).branch(1).loc(0.0).stimulate(current) net2.cell(0).branch(1).loc(0.0).stimulate(current) voltages1 = jx.integrate(net1, params=params1) @@ -397,7 +401,9 @@ def test_data_set_vs_make_trainable_pospischil(SimpleNet): def test_data_set_vs_make_trainable_network(SimpleNet): net1 = SimpleNet(2, 4, 1) net2 = SimpleNet(2, 4, 1) - current = jx.step_current(0.1, 4.0, 0.1, 0.025, 5.0) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) for net in [net1, net2]: net.insert(HH()) pre = net.cell(0).branch(0).loc(0.0) @@ -483,7 +489,9 @@ def test_write_trainables(SimpleNet): net.insert(HH()) net.cell(0).branch(0).comp(0).record() net.cell(1).branch(0).comp(0).record() - net.cell(0).branch(0).comp(0).stimulate(jx.step_current(0.1, 4.0, 0.1, 0.025, 5.0)) + net.cell(0).branch(0).comp(0).stimulate( + jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0) + ) net.make_trainable("radius") net.cell(0).make_trainable("length") diff --git a/tests/test_optimize.py b/tests/test_optimize.py index 52cc7ea6..d094604e 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -22,7 +22,10 @@ def test_type_optimizer_api(SimpleComp): comp = SimpleComp(copy=True) comp.insert(HH()) comp.record() - comp.stimulate(jx.step_current(0.1, 3.0, 0.1, 0.025, 5.0)) + current = jx.step_current( + i_delay=0.1, i_dur=3.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + ) + comp.stimulate(current) def simulate(params): return jx.integrate(comp, params=params) @@ -53,7 +56,10 @@ def test_type_optimizer(SimpleComp): comp = SimpleComp(copy=True) comp.insert(HH()) comp.record() - comp.stimulate(jx.step_current(0.1, 3.0, 0.1, 0.025, 5.0)) + current = jx.step_current( + i_delay=0.1, i_dur=3.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + ) + comp.stimulate(current) comp.set("HH_gNa", 0.4) comp.set("radius", 30.0) diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py index 25186a13..9676097e 100644 --- a/tests/test_record_and_stimulate.py +++ b/tests/test_record_and_stimulate.py @@ -23,7 +23,9 @@ def test_record_and_stimulate_api(SimpleCell): cell.branch(0).loc(0.0).record() cell.branch(1).loc(1.0).record() - current = jx.step_current(0.0, 1.0, 1.0, 0.025, 3.0) + current = jx.step_current( + i_delay=0.0, i_dur=0.0, i_amp=0.0, delta_t=0.025, t_max=0.1 + ) cell.branch(1).loc(1.0).stimulate(current) cell.delete_recordings() @@ -34,7 +36,9 @@ def test_record_shape(SimpleCell): """Test the API for recording and stimulating.""" cell = SimpleCell(3, 2) - current = jx.step_current(0.0, 1.0, 1.0, 0.025, 3.0) + current = jx.step_current( + i_delay=0.0, i_dur=0.0, i_amp=0.0, delta_t=0.025, t_max=0.1 + ) cell.branch(1).loc(1.0).stimulate(current) cell.branch(0).loc(0.0).record() @@ -66,7 +70,9 @@ def test_record_synaptic_and_membrane_states(SimpleNet): fully_connect(net.cell([1]), net.cell([2]), TestSynapse()) fully_connect(net.cell([2]), net.cell([0]), IonotropicSynapse()) - current = jx.step_current(1.0, 80.0, 0.02, 0.025, 100.0) + current = jx.step_current( + i_delay=1.0, i_dur=80.0, i_amp=0.02, delta_t=0.025, t_max=100.0 + ) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(2).branch(0).loc(0.0).record("v") diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 3e7642ce..3835f3e9 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -202,7 +202,9 @@ def test_shared_state(): voltages = [] for comp in [comp1, comp2, comp3]: comp.record() - current = jx.step_current(0.1, 0.1, 0.1, 0.025, 0.3) + current = jx.step_current( + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + ) comp.stimulate(current) voltages.append(jx.integrate(comp)) diff --git a/tests/test_solver.py b/tests/test_solver.py index 88a86c59..5e6de2b7 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -30,7 +30,9 @@ def test_fwd_euler_and_crank_nicolson(SimpleNet): Euler.""" net = SimpleNet(2, 1, 4, connect=True) - current = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0) + current = jx.step_current( + i_delay=0.5, i_dur=0.5, i_amp=0.1, delta_t=0.025, t_max=10.0 + ) net.cell(0).branch(0).comp(0).stimulate(current) net.cell(1).branch(0).comp(3).record() From 5fc35f886f61f649fd36962cdb4fc8bd1b19b140 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 19 Nov 2024 11:26:37 +0100 Subject: [PATCH 12/15] fix: standardize and shorten stimuli --- tests/conftest.py | 16 ---------------- tests/test_api_equivalence.py | 14 +++++++------- tests/test_cell_matches_branch.py | 2 +- tests/test_channels.py | 6 +++--- tests/test_composability_of_modules.py | 6 +++--- tests/test_data_feeding.py | 12 ++++++------ tests/test_grad.py | 4 ++-- tests/test_make_trainable.py | 8 ++++---- tests/test_optimize.py | 4 ++-- tests/test_record_and_stimulate.py | 2 +- tests/test_shared_state.py | 2 +- tests/test_solver.py | 2 +- tests/test_viewing.py | 2 +- 13 files changed, 32 insertions(+), 48 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d6014356..89621245 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,22 +10,6 @@ from jaxley.synapses import IonotropicSynapse -@pytest.fixture(scope="session") -def step_current(): - def get_step_current( - i_delay: float = 0.5, - i_dur: float = 1.0, - i_amp: float = 0.1, - dt: float = 0.025, - t_max: float = 3.0, - i_offset: float = 0.0, - ): - """Create a step current stimulus.""" - return jx.step_current(i_delay, i_dur, i_amp, dt, t_max, i_offset) - - yield get_step_current - - @pytest.fixture(scope="session") def SimpleComp(): """Fixture for creating or retrieving an already created compartment.""" diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py index 2146a870..fa3ceaf2 100644 --- a/tests/test_api_equivalence.py +++ b/tests/test_api_equivalence.py @@ -39,7 +39,7 @@ def test_api_equivalence_morphology(SimpleComp): cell2.branch(2).loc(0.4).record() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) cell1.branch(1).loc(1.0).stimulate(current) cell2.branch(1).loc(1.0).stimulate(current) @@ -56,7 +56,7 @@ def test_solver_backends_comp(SimpleComp): comp = SimpleComp() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) comp.stimulate(current) comp.record() @@ -74,7 +74,7 @@ def test_solver_backends_branch(SimpleBranch): branch = SimpleBranch(4) current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) branch.loc(0.0).stimulate(current) branch.loc(0.5).record() @@ -93,7 +93,7 @@ def test_solver_backends_cell(SimpleCell): cell = SimpleCell(4, 4) current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) cell.branch(0).loc(0.0).stimulate(current) cell.branch(0).loc(0.5).record() @@ -123,7 +123,7 @@ def test_solver_backends_net(SimpleNet): ) current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() @@ -163,7 +163,7 @@ def test_api_equivalence_synapses(SimpleNet): for net in [net1, net2]: current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() @@ -196,7 +196,7 @@ def test_api_equivalence_network_matches_cell(SimpleBranch): This runs an unequal number of compartments per branch.""" dt = 0.025 # ms current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) branch1 = SimpleBranch(nseg=1) diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index 18d06555..e415cf68 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -60,7 +60,7 @@ def test_equivalence(SimpleBranch, SimpleCell): dt = 0.025 t_max = 5.0 # ms current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) l1, g1 = _run_long_branch(dt, t_max, current, SimpleBranch(8)) l2, g2 = _run_short_branches(dt, t_max, current, SimpleCell(2, 4)) diff --git a/tests/test_channels.py b/tests/test_channels.py index f9c4cf23..c8e725ad 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -274,7 +274,7 @@ def test_init_states_complex_channel(SimpleCell): cell.init_states() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) cell.branch(2).comp(0).stimulate(current) cell.branch(2).comp(0).record() @@ -332,10 +332,10 @@ def compute_current(self, states, v, params): return 0.01 * jnp.ones_like(v) dt = 0.025 # ms - t_max = 10.0 # ms + t_max = 2.0 # ms cell = SimpleCell(1, 1) cell.branch(0).loc(0.0).stimulate( - jx.step_current(i_delay=0.5, i_dur=0.5, i_amp=0.1, delta_t=0.025, t_max=10.0) + jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) ) cell.insert(User()) diff --git a/tests/test_composability_of_modules.py b/tests/test_composability_of_modules.py index 3ad5a494..9b834293 100644 --- a/tests/test_composability_of_modules.py +++ b/tests/test_composability_of_modules.py @@ -16,7 +16,7 @@ def test_compose_branch(): """Test inserting to comp and composing to branch equals inserting to branch.""" dt = 0.025 current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) comp1 = jx.Compartment() @@ -43,7 +43,7 @@ def test_compose_cell(): nseg_per_branch = 4 dt = 0.025 current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) comp = jx.Compartment() @@ -72,7 +72,7 @@ def test_compose_net(): nseg_per_branch = 4 dt = 0.025 current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) comp = jx.Compartment() diff --git a/tests/test_data_feeding.py b/tests/test_data_feeding.py index 93f6f4c9..62375cb2 100644 --- a/tests/test_data_feeding.py +++ b/tests/test_data_feeding.py @@ -21,17 +21,17 @@ def test_constant_and_data_stimulus(SimpleCell): # test data_stimulate and jit works with trainable parameters see #467 cell.make_trainable("radius") - i_amp_const = 0.02 + i_amp_const = 0.1 i_amps_data = jnp.asarray([0.01, 0.005]) current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.02, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=i_amp_const, delta_t=0.025, t_max=2.0 ) cell.branch(1).loc(0.6).stimulate(current) def provide_data(i_amps): current = jx.datapoint_to_step_currents( - i_delay=0.5, i_dur=1.0, i_amp=i_amps, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=i_amps, delta_t=0.025, t_max=2.0 ) data_stimuli = None data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[0], data_stimuli) @@ -48,7 +48,7 @@ def simulate(i_amps): cell.delete_stimuli() i_amp_summed = i_amp_const + jnp.sum(i_amps_data) current_sum = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=2.0 ) cell.branch(1).loc(0.6).stimulate(current_sum) @@ -65,7 +65,7 @@ def test_data_vs_constant_stimulus(SimpleCell): i_amps_data = jnp.asarray([0.01, 0.005]) def provide_data(i_amps): - current = jx.datapoint_to_step_currents(0.5, 1.0, i_amps, 0.025, 3.0) + current = jx.datapoint_to_step_currents(0.1, 1.0, i_amps, 0.025, 2.0) data_stimuli = None data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[0], data_stimuli) data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[1], data_stimuli) @@ -81,7 +81,7 @@ def simulate(i_amps): cell.delete_stimuli() i_amp_summed = jnp.sum(i_amps_data) current_sum = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=2.0 ) cell.branch(1).loc(0.6).stimulate(current_sum) diff --git a/tests/test_grad.py b/tests/test_grad.py index 9c3df885..3cbcebde 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -26,7 +26,7 @@ def simulate_with_params(params): comp.insert(HH()) comp.record() comp.stimulate( - jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0) + jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) ) val = 0.2 if key == "HH_m" else -70.0 @@ -62,7 +62,7 @@ def simulate_with_params(params): branch = SimpleBranch(4) branch.loc(0.0).record() branch.loc(0.0).stimulate( - jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0) + jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) ) branch.loc(0.0).insert(HH()) diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index a3122aed..3f138b19 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -322,7 +322,7 @@ def test_make_trainable_corresponds_to_set_pospischil(SimpleNet): net2.cell(1).branch(1).loc(0.0).record() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) net1.cell(0).branch(1).loc(0.0).stimulate(current) net2.cell(0).branch(1).loc(0.0).stimulate(current) @@ -389,7 +389,7 @@ def test_data_set_vs_make_trainable_pospischil(SimpleNet): net2.cell(1).branch(1).loc(0.0).record() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.5, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) net1.cell(0).branch(1).loc(0.0).stimulate(current) net2.cell(0).branch(1).loc(0.0).stimulate(current) @@ -402,7 +402,7 @@ def test_data_set_vs_make_trainable_network(SimpleNet): net1 = SimpleNet(2, 4, 1) net2 = SimpleNet(2, 4, 1) current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) for net in [net1, net2]: net.insert(HH()) @@ -490,7 +490,7 @@ def test_write_trainables(SimpleNet): net.cell(0).branch(0).comp(0).record() net.cell(1).branch(0).comp(0).record() net.cell(0).branch(0).comp(0).stimulate( - jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0) + jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) ) net.make_trainable("radius") diff --git a/tests/test_optimize.py b/tests/test_optimize.py index d094604e..5668e5e8 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -23,7 +23,7 @@ def test_type_optimizer_api(SimpleComp): comp.insert(HH()) comp.record() current = jx.step_current( - i_delay=0.1, i_dur=3.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) comp.stimulate(current) @@ -57,7 +57,7 @@ def test_type_optimizer(SimpleComp): comp.insert(HH()) comp.record() current = jx.step_current( - i_delay=0.1, i_dur=3.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp.stimulate(current) diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py index 9676097e..6c623283 100644 --- a/tests/test_record_and_stimulate.py +++ b/tests/test_record_and_stimulate.py @@ -71,7 +71,7 @@ def test_record_synaptic_and_membrane_states(SimpleNet): fully_connect(net.cell([2]), net.cell([0]), IonotropicSynapse()) current = jx.step_current( - i_delay=1.0, i_dur=80.0, i_amp=0.02, delta_t=0.025, t_max=100.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) net.cell(0).branch(0).loc(0.0).stimulate(current) diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 3835f3e9..14f0bcb9 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -203,7 +203,7 @@ def test_shared_state(): for comp in [comp1, comp2, comp3]: comp.record() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=3.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) comp.stimulate(current) diff --git a/tests/test_solver.py b/tests/test_solver.py index 5e6de2b7..8d9160e9 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -31,7 +31,7 @@ def test_fwd_euler_and_crank_nicolson(SimpleNet): net = SimpleNet(2, 1, 4, connect=True) current = jx.step_current( - i_delay=0.5, i_dur=0.5, i_amp=0.1, delta_t=0.025, t_max=10.0 + i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 ) net.cell(0).branch(0).comp(0).stimulate(current) net.cell(1).branch(0).comp(3).record() diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 547120e7..3281db34 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -159,7 +159,7 @@ def test_set_and_insert(SimpleBranch, SimpleCell, SimpleNet): # test insert multiple stimuli single_current = jx.step_current( - i_delay=10.0, i_dur=80.0, i_amp=5.0, delta_t=0.025, t_max=100.0 + i_delay=0.0, i_dur=0.0, i_amp=0.0, delta_t=0.025, t_max=0.1 ) batch_of_currents = np.vstack([single_current for _ in range(4)]) From 098dfaa698c29c16a0d26632608bdabece471d8c Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 20 Nov 2024 16:32:29 +0100 Subject: [PATCH 13/15] fix: redo currents --- tests/test_api_equivalence.py | 14 +++++++------- tests/test_cell_matches_branch.py | 2 +- tests/test_channels.py | 4 ++-- tests/test_composability_of_modules.py | 6 +++--- tests/test_data_feeding.py | 10 +++++----- tests/test_grad.py | 4 ++-- tests/test_make_trainable.py | 8 ++++---- tests/test_optimize.py | 4 ++-- tests/test_record_and_stimulate.py | 6 +++--- tests/test_shared_state.py | 2 +- tests/test_solver.py | 2 +- tests/test_viewing.py | 2 +- 12 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py index fa3ceaf2..9fef8759 100644 --- a/tests/test_api_equivalence.py +++ b/tests/test_api_equivalence.py @@ -39,7 +39,7 @@ def test_api_equivalence_morphology(SimpleComp): cell2.branch(2).loc(0.4).record() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) cell1.branch(1).loc(1.0).stimulate(current) cell2.branch(1).loc(1.0).stimulate(current) @@ -56,7 +56,7 @@ def test_solver_backends_comp(SimpleComp): comp = SimpleComp() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp.stimulate(current) comp.record() @@ -74,7 +74,7 @@ def test_solver_backends_branch(SimpleBranch): branch = SimpleBranch(4) current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) branch.loc(0.0).stimulate(current) branch.loc(0.5).record() @@ -93,7 +93,7 @@ def test_solver_backends_cell(SimpleCell): cell = SimpleCell(4, 4) current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) cell.branch(0).loc(0.0).stimulate(current) cell.branch(0).loc(0.5).record() @@ -123,7 +123,7 @@ def test_solver_backends_net(SimpleNet): ) current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() @@ -163,7 +163,7 @@ def test_api_equivalence_synapses(SimpleNet): for net in [net1, net2]: current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) net.cell(0).branch(0).loc(0.0).stimulate(current) net.cell(0).branch(0).loc(0.5).record() @@ -196,7 +196,7 @@ def test_api_equivalence_network_matches_cell(SimpleBranch): This runs an unequal number of compartments per branch.""" dt = 0.025 # ms current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) branch1 = SimpleBranch(nseg=1) diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index e415cf68..d87d8ab8 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -60,7 +60,7 @@ def test_equivalence(SimpleBranch, SimpleCell): dt = 0.025 t_max = 5.0 # ms current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) l1, g1 = _run_long_branch(dt, t_max, current, SimpleBranch(8)) l2, g2 = _run_short_branches(dt, t_max, current, SimpleCell(2, 4)) diff --git a/tests/test_channels.py b/tests/test_channels.py index c8e725ad..44a30dae 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -274,7 +274,7 @@ def test_init_states_complex_channel(SimpleCell): cell.init_states() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) cell.branch(2).comp(0).stimulate(current) cell.branch(2).comp(0).record() @@ -335,7 +335,7 @@ def compute_current(self, states, v, params): t_max = 2.0 # ms cell = SimpleCell(1, 1) cell.branch(0).loc(0.0).stimulate( - jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) + jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0) ) cell.insert(User()) diff --git a/tests/test_composability_of_modules.py b/tests/test_composability_of_modules.py index 9b834293..66f3457e 100644 --- a/tests/test_composability_of_modules.py +++ b/tests/test_composability_of_modules.py @@ -16,7 +16,7 @@ def test_compose_branch(): """Test inserting to comp and composing to branch equals inserting to branch.""" dt = 0.025 current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp1 = jx.Compartment() @@ -43,7 +43,7 @@ def test_compose_cell(): nseg_per_branch = 4 dt = 0.025 current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp = jx.Compartment() @@ -72,7 +72,7 @@ def test_compose_net(): nseg_per_branch = 4 dt = 0.025 current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp = jx.Compartment() diff --git a/tests/test_data_feeding.py b/tests/test_data_feeding.py index 62375cb2..1f6b5cdd 100644 --- a/tests/test_data_feeding.py +++ b/tests/test_data_feeding.py @@ -25,13 +25,13 @@ def test_constant_and_data_stimulus(SimpleCell): i_amps_data = jnp.asarray([0.01, 0.005]) current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=i_amp_const, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=i_amp_const, delta_t=0.025, t_max=5.0 ) cell.branch(1).loc(0.6).stimulate(current) def provide_data(i_amps): current = jx.datapoint_to_step_currents( - i_delay=0.1, i_dur=1.0, i_amp=i_amps, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=i_amps, delta_t=0.025, t_max=5.0 ) data_stimuli = None data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[0], data_stimuli) @@ -48,7 +48,7 @@ def simulate(i_amps): cell.delete_stimuli() i_amp_summed = i_amp_const + jnp.sum(i_amps_data) current_sum = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=5.0 ) cell.branch(1).loc(0.6).stimulate(current_sum) @@ -65,7 +65,7 @@ def test_data_vs_constant_stimulus(SimpleCell): i_amps_data = jnp.asarray([0.01, 0.005]) def provide_data(i_amps): - current = jx.datapoint_to_step_currents(0.1, 1.0, i_amps, 0.025, 2.0) + current = jx.datapoint_to_step_currents(0.5, 1.0, i_amps, 0.025, 5.0) data_stimuli = None data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[0], data_stimuli) data_stimuli = cell.branch(1).loc(0.6).data_stimulate(current[1], data_stimuli) @@ -81,7 +81,7 @@ def simulate(i_amps): cell.delete_stimuli() i_amp_summed = jnp.sum(i_amps_data) current_sum = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=i_amp_summed, delta_t=0.025, t_max=5.0 ) cell.branch(1).loc(0.6).stimulate(current_sum) diff --git a/tests/test_grad.py b/tests/test_grad.py index 3cbcebde..b41b5d37 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -26,7 +26,7 @@ def simulate_with_params(params): comp.insert(HH()) comp.record() comp.stimulate( - jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) + jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0) ) val = 0.2 if key == "HH_m" else -70.0 @@ -62,7 +62,7 @@ def simulate_with_params(params): branch = SimpleBranch(4) branch.loc(0.0).record() branch.loc(0.0).stimulate( - jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) + jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0) ) branch.loc(0.0).insert(HH()) diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 3f138b19..db909a29 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -322,7 +322,7 @@ def test_make_trainable_corresponds_to_set_pospischil(SimpleNet): net2.cell(1).branch(1).loc(0.0).record() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) net1.cell(0).branch(1).loc(0.0).stimulate(current) net2.cell(0).branch(1).loc(0.0).stimulate(current) @@ -389,7 +389,7 @@ def test_data_set_vs_make_trainable_pospischil(SimpleNet): net2.cell(1).branch(1).loc(0.0).record() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) net1.cell(0).branch(1).loc(0.0).stimulate(current) net2.cell(0).branch(1).loc(0.0).stimulate(current) @@ -402,7 +402,7 @@ def test_data_set_vs_make_trainable_network(SimpleNet): net1 = SimpleNet(2, 4, 1) net2 = SimpleNet(2, 4, 1) current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) for net in [net1, net2]: net.insert(HH()) @@ -490,7 +490,7 @@ def test_write_trainables(SimpleNet): net.cell(0).branch(0).comp(0).record() net.cell(1).branch(0).comp(0).record() net.cell(0).branch(0).comp(0).stimulate( - jx.step_current(i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0) + jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0) ) net.make_trainable("radius") diff --git a/tests/test_optimize.py b/tests/test_optimize.py index 5668e5e8..b2f5fdf8 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -23,7 +23,7 @@ def test_type_optimizer_api(SimpleComp): comp.insert(HH()) comp.record() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp.stimulate(current) @@ -57,7 +57,7 @@ def test_type_optimizer(SimpleComp): comp.insert(HH()) comp.record() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp.stimulate(current) diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py index 6c623283..151c3474 100644 --- a/tests/test_record_and_stimulate.py +++ b/tests/test_record_and_stimulate.py @@ -24,7 +24,7 @@ def test_record_and_stimulate_api(SimpleCell): cell.branch(1).loc(1.0).record() current = jx.step_current( - i_delay=0.0, i_dur=0.0, i_amp=0.0, delta_t=0.025, t_max=0.1 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) cell.branch(1).loc(1.0).stimulate(current) @@ -37,7 +37,7 @@ def test_record_shape(SimpleCell): cell = SimpleCell(3, 2) current = jx.step_current( - i_delay=0.0, i_dur=0.0, i_amp=0.0, delta_t=0.025, t_max=0.1 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) cell.branch(1).loc(1.0).stimulate(current) @@ -71,7 +71,7 @@ def test_record_synaptic_and_membrane_states(SimpleNet): fully_connect(net.cell([2]), net.cell([0]), IonotropicSynapse()) current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) net.cell(0).branch(0).loc(0.0).stimulate(current) diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 14f0bcb9..0de88bb5 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -203,7 +203,7 @@ def test_shared_state(): for comp in [comp1, comp2, comp3]: comp.record() current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp.stimulate(current) diff --git a/tests/test_solver.py b/tests/test_solver.py index 8d9160e9..99577e69 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -31,7 +31,7 @@ def test_fwd_euler_and_crank_nicolson(SimpleNet): net = SimpleNet(2, 1, 4, connect=True) current = jx.step_current( - i_delay=0.1, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=2.0 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) net.cell(0).branch(0).comp(0).stimulate(current) net.cell(1).branch(0).comp(3).record() diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 3281db34..5ebf65d1 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -159,7 +159,7 @@ def test_set_and_insert(SimpleBranch, SimpleCell, SimpleNet): # test insert multiple stimuli single_current = jx.step_current( - i_delay=0.0, i_dur=0.0, i_amp=0.0, delta_t=0.025, t_max=0.1 + i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) batch_of_currents = np.vstack([single_current for _ in range(4)]) From 6508cf2809ff4ea7bde6a8ed4bc5378897eb07e3 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 20 Nov 2024 16:53:38 +0100 Subject: [PATCH 14/15] fix: fix tests not passing --- tests/test_channels.py | 2 +- tests/test_optimize.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_channels.py b/tests/test_channels.py index 44a30dae..f5bdd2b1 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -332,7 +332,7 @@ def compute_current(self, states, v, params): return 0.01 * jnp.ones_like(v) dt = 0.025 # ms - t_max = 2.0 # ms + t_max = 5.0 # ms cell = SimpleCell(1, 1) cell.branch(0).loc(0.0).stimulate( jx.step_current(i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0) diff --git a/tests/test_optimize.py b/tests/test_optimize.py index b2f5fdf8..d094604e 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -23,7 +23,7 @@ def test_type_optimizer_api(SimpleComp): comp.insert(HH()) comp.record() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + i_delay=0.1, i_dur=3.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp.stimulate(current) @@ -57,7 +57,7 @@ def test_type_optimizer(SimpleComp): comp.insert(HH()) comp.record() current = jx.step_current( - i_delay=0.5, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=5.0 + i_delay=0.1, i_dur=3.0, i_amp=0.1, delta_t=0.025, t_max=5.0 ) comp.stimulate(current) From 3f7769708223749cb60d7938b5bd2f10d4253637 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 20 Nov 2024 17:06:39 +0100 Subject: [PATCH 15/15] fix: rebase and add fixes --- tests/test_plotting_api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index 2193bd31..c3857215 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -23,6 +23,7 @@ def test_cell(SimpleMorphCell): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", "morph.swc") cell = SimpleMorphCell(fname, nseg=1) + cell.branch(0).set_ncomp(2) # test inhomogeneous ncomp # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) @@ -81,10 +82,11 @@ def test_network(SimpleMorphCell): ax = net.excitatory.vis() -def test_vis_networks_built_from_scartch(SimpleComp, SimpleBranch, SimpleCell): +def test_vis_networks_built_from_scratch(SimpleComp, SimpleBranch, SimpleCell): comp = SimpleComp(copy=True) branch = SimpleBranch(4) cell = SimpleCell(5, 3) + cell.branch(0).set_ncomp(3) # test inhomogeneous ncomp net = jx.Network([cell, cell]) connect( @@ -162,7 +164,9 @@ def test_volume_plotting( comp = SimpleComp() branch = SimpleBranch(2) cell = SimpleCell(2, 2) + cell.branch(0).set_ncomp(3) # test inhomogeneous ncomp net = SimpleNet(2, 2, 2) + for module in [comp, branch, cell, net]: module.compute_xyz()