Skip to content

Commit

Permalink
Add support for graph operations (#355)
Browse files Browse the repository at this point in the history
* add: add to_graph method

* add: add from_graph method

* fix: isort and rm obs import

* doc: added more comments

* wip: started adding new swc import function

* wip: save dev wip

* add: implement graph2jaxley method, that imposes branch structure on some graph

* fix: small fixes

* enh: add prelim version of xyzr

* wip: save wip

* enh: refactor of swc->graph->jaxley->module pipeline

* enh: cleanup

* fix: small fix, now from graph runs on compartmentalized morphology

* doc: add docstrings and type hints

* doc: more comments added

* fix: small fixes.

* fix: rm undefined from groups at import

* wip: save wip

* wip: save wip

* rm: remove dev notebook from tracking

* add: add tests for graph functionalities

* fix: added more tests and they are now passing

* fix: test fixes

* wip: save wip working on graph and swc io

* fix: radius import fixed

* wip: save wip

* enh: massive overhaul complete of graph pipeline compared to before. now much simpler and neuron comparison tests pass.

* wip: save wip

* wip: save wip

* fix: some fixes added

* wip: save wip

* wip: save wip tests

* wip: test look better than before.

* wip: remove complexity and improve test MSE. still not there though

* wip: save wip.

* rm: rm notebooks from pr

* wip: save wip

* enh: small refactor of swc -> initial graph pipe

* enh: massive overhaul auf graph pipe. add documentation. passes tests now

* wip: start adding tests.

* enh: incl graphIO in tutorial

* wip: progress on tests and tutorial

* wip: tests are passing. except for voltages, which only passes in notebook but not in pytest

* rm: dev notebook removed

* fix: tests passing for non-single soma morpho

* enh: Tests are finally passing

* fix: add misssing kwarg in simulate_trace_error

* fix: fix diff with main

* fix: rm diff in modules/base

* fix: add __eq__ back in for comparisons of cells attr in net and fix asteric

* chore: ran black

* fix: change read_swc to io imports

* chore: add license header

* enh: fixup of wording

* fix: fix merge artefact

* rm: rm tutorials left from rebase

* wip: swc pipe now works up until module import

* wip: in/out pipeline working for morphology, but w.o. attrs like recordings.

* fix: update tutorial

* fix: fix tutorial and synapses and input output

* fix: add networkx as dep

* fix rebase io/swc.py to main

* doc: add more function documentation

* wip: step 1 on getting tests to pass

* fix: reduce diff

* fix: ammend last commit

* fix: pop l when converting graph

* wip: working on tests

* wip: all but one test passing, working on import export cycle validation

* fix: finished import export tests and all tests are passing

* fix: speed up tests and add docs

* chore: edit changelog

* fix: fix import cycle error, add min_radius and make some functions private

* add: add new test morphology

* fix: refactor swc and combine both swc_readers into a combined method

* wip: change root finder and fix testcase

* fix: made more methods private, fix issues with 0 length edges in graph

* chore: edit changelog

* chore: rfmt changelog

* fix: skip new morph test for now
jnsbck authored Jan 31, 2025

Verified

This commit was signed with the committer’s verified signature.
booc0mtaco Holloway
1 parent e6f7ab0 commit 0753c2e
Showing 14 changed files with 12,871 additions and 292 deletions.
21 changes: 18 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -10,9 +10,6 @@ Installing `Jaxley` will no longer install the newest version of `JAX`. We reali
```python
net.record("i_IonotropicSynapse")
```
- Add regression tests and supporting workflows for maintaining baselines (#475, #546, @jnsbck).
- Regression tests can be triggered by commenting "/test_regression" on a PR.
- Regression tests can be done locally by running `NEW_BASELINE=1 pytest -m regression` i.e. on `main` and then `pytest -m regression` on `feature`, which will produce a test report (printed to the console and saved to .txt).

- refactor plotting (#539, @jnsbck).
- rm networkx dependency
@@ -25,8 +22,26 @@ net.vis()

- Allow parameter sharing for groups of different sizes, i.e. due to inhomogenous numbers of compartments or for synapses with the same (pre-)synaptic parameters but different numbers of post-synaptic partners. (#514, @jnsbck)

- Add `jaxley.io.graph` for exporting and importing of jaxley modules to and from `networkx` graph objects (#355, @jnsbck).
- Adds a new (and improved) SWC reader, which is more flexible and should also be easier to extend in the future.
```python
from jaxley.io.graph import swc_to_graph, from_graph
graph = swc_to_graph(fname)
# do something to the swc graph, i.e. prune it
pruned_graph = do_something_to_graph(graph)
cell = from_graph(pruned_graph, ncomp=4)
```
- Adds a new `to_graph` method for jaxley modules, which exports a module to a `networkX` graph. This allows to seamlessly work with `networkX`'s graph manipulation or visualization functions.
- `"graph"` can now also be selected as a backend in the `read_swc`.
- See [the improved SWC reader tutorial](https://jaxley.readthedocs.io/en/latest/tutorials/08_importing_morphologies.html) for more details.

### Code Health
- changelog added to CI (#537, #558, @jnsbck)

- Add regression tests and supporting workflows for maintaining baselines (#475, #546, @jnsbck).
- Regression tests can be triggered by commenting "/test_regression" on a PR.
- Regression tests can be done locally by running `NEW_BASELINE=1 pytest -m regression` i.e. on `main` and then `pytest -m regression` on `feature`, which will produce a test report (printed to the console and saved to .txt).

### Bug fixes
- Fixed inconsistency with *type* assertions arising due to `numpy` functions returning different `dtypes` on platforms like Windows (#567, @Kartik-Sama)

282 changes: 269 additions & 13 deletions docs/tutorials/08_importing_morphologies.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions jaxley/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from jaxley.io.graph import from_graph, to_graph
from jaxley.io.swc import read_swc
916 changes: 916 additions & 0 deletions jaxley/io/graph.py

Large diffs are not rendered by default.

321 changes: 311 additions & 10 deletions jaxley/io/swc.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,270 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from copy import copy
from functools import partial
from typing import Callable, List, Optional, Tuple
from warnings import warn

import jax.numpy as jnp
import numpy as np

from jaxley.io.graph import from_graph, swc_to_graph
from jaxley.modules import Branch, Cell, Compartment
from jaxley.utils.cell_utils import (
_build_parents,
_compute_pathlengths,
_padded_radius_generating_fn,
_radius_generating_fns,
_split_into_branches_and_sort,
build_radiuses_from_xyzr,
)
from jaxley.utils.cell_utils import build_radiuses_from_xyzr
from jaxley.utils.misc_utils import deprecated_kwargs


def _split_long_branches(
branches: np.ndarray,
types: np.ndarray,
content: np.ndarray,
max_branch_len: float,
is_single_point_soma: bool,
) -> Tuple[np.ndarray, np.ndarray]:
pathlengths = _compute_pathlengths(
branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
)
pathlengths = [np.sum(length_traced) for length_traced in pathlengths]
split_branches = []
split_types = []
for branch, type, length in zip(branches, types, pathlengths):
num_subbranches = 1
split_branch = [branch]
while length > max_branch_len:
num_subbranches += 1
split_branch = _split_branch_equally(branch, num_subbranches)
lengths_of_subbranches = _compute_pathlengths(
split_branch,
coords=content[:, 1:6],
is_single_point_soma=is_single_point_soma,
)
lengths_of_subbranches = [
np.sum(length_traced) for length_traced in lengths_of_subbranches
]
length = max(lengths_of_subbranches)
if num_subbranches > 10:
warn(
"""`num_subbranches > 10`, stopping to split. Most likely your
SWC reconstruction is not dense and some neighbouring traced
points are farther than `max_branch_len` apart."""
)
break
split_branches += split_branch
split_types += [type] * num_subbranches

return split_branches, split_types


def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]:
num_points_each = len(branch) // num_subbranches
branches = [branch[:num_points_each]]
for i in range(1, num_subbranches - 1):
branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each])
branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :])
return branches


def _split_into_branches(
content: np.ndarray, is_single_point_soma: bool
) -> Tuple[np.ndarray, np.ndarray]:
prev_ind = None
prev_type = None
n_branches = 0

# Branch inds will contain the row identifier at which a branch point occurs
# (i.e. the row of the parent of two branches).
branch_inds = []
for c in content:
current_ind = c[0]
current_parent = c[-1]
current_type = c[1]
if current_parent != prev_ind or current_type != prev_type:
branch_inds.append(int(current_parent))
n_branches += 1
prev_ind = current_ind
prev_type = current_type

all_branches = []
current_branch = []
all_types = []

# Loop over every line in the SWC file.
for c in content:
current_ind = c[0] # First col is row_identifier
current_parent = c[-1] # Last col is parent in SWC specification.
if current_parent == -1:
all_types.append(c[1])
else:
current_type = c[1]

if current_parent == -1 and is_single_point_soma and current_ind == 1:
all_branches.append([int(current_ind)])
all_types.append(int(current_type))

# Either append the current point to the branch, or add the branch to
# `all_branches`.
if current_parent in branch_inds[1:]:
if len(current_branch) > 1:
all_branches.append(current_branch)
all_types.append(current_type)
current_branch = [int(current_parent), int(current_ind)]
else:
current_branch.append(int(current_ind))

# Append the final branch (intermediate branches are already appended five lines
# above.)
all_branches.append(current_branch)
return all_branches, all_types


def _split_into_branches_and_sort(
content: np.ndarray,
max_branch_len: Optional[float],
is_single_point_soma: bool,
sort: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
branches, types = _split_into_branches(content, is_single_point_soma)
if max_branch_len is not None:
branches, types = _split_long_branches(
branches,
types,
content,
max_branch_len,
is_single_point_soma=is_single_point_soma,
)

if sort:
first_val = np.asarray([b[0] for b in branches])
sorting = np.argsort(first_val, kind="mergesort")
sorted_branches = [branches[s] for s in sorting]
sorted_types = [types[s] for s in sorting]
else:
sorted_branches = branches
sorted_types = types
return sorted_branches, sorted_types


def _radius_generating_fns(
all_branches: np.ndarray,
radiuses: np.ndarray,
each_length: np.ndarray,
parents: np.ndarray,
types: np.ndarray,
) -> List[Callable]:
"""For all branches in a cell, returns callable that return radius given loc."""
radius_fns = []
for i, branch in enumerate(all_branches):
rads_in_branch = radiuses[np.asarray(branch) - 1]
if parents[i] > -1 and types[i] != types[parents[i]]:
# We do not want to linearly interpolate between the radius of the previous
# branch if a new type of neurite is found (e.g. switch from soma to
# apical). From looking at the SWC from n140.swc I believe that this is
# also what NEURON does.
rads_in_branch[0] = rads_in_branch[1]
radius_fn = _radius_generating_fn(
radiuses=rads_in_branch, each_length=each_length[i]
)
# Beause SWC starts counting at 1, but numpy counts from 0.
# ind_of_branch_endpoint = np.asarray(b) - 1
radius_fns.append(radius_fn)
return radius_fns


def _padded_radius(loc: float, radiuses: np.ndarray) -> float:
return radiuses * np.ones_like(loc)


def _radius(loc: float, cutoffs: np.ndarray, radiuses: np.ndarray) -> float:
"""Function which returns the radius via linear interpolation.
Defined outside of `_radius_generating_fns` to allow for pickling of the resulting
Cell object."""
index = np.digitize(loc, cutoffs, right=False)
left_rad = radiuses[index - 1]
right_rad = radiuses[index]
left_loc = cutoffs[index - 1]
right_loc = cutoffs[index]
loc_within_bin = (loc - left_loc) / (right_loc - left_loc)
return left_rad + (right_rad - left_rad) * loc_within_bin


def _padded_radius_generating_fn(radiuses: np.ndarray) -> Callable:
return partial(_padded_radius, radiuses=radiuses)


def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable:
# Avoid division by 0 with the `summed_len` below.
each_length[each_length < 1e-8] = 1e-8
summed_len = np.sum(each_length)
cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len
cutoffs[0] -= 1e-8
cutoffs[-1] += 1e-8

# We have to linearly interpolate radiuses, therefore we need at least two radiuses.
# However, jaxley allows somata which consist of a single traced point (i.e.
# just one radius). Therefore, we just `tile` in order to generate an artificial
# endpoint and startpoint radius of the soma.
if len(radiuses) == 1:
radiuses = np.tile(radiuses, 2)

return partial(_radius, cutoffs=cutoffs, radiuses=radiuses)


def _build_parents(all_branches: List[np.ndarray]) -> List[int]:
parents = [None] * len(all_branches)
all_last_inds = [b[-1] for b in all_branches]
for i, branch in enumerate(all_branches):
parent_ind = branch[0]
ind = np.where(np.asarray(all_last_inds) == parent_ind)[0]
if len(ind) > 0 and ind != i:
parents[i] = ind[0]
else:
assert (
parent_ind == 1
), """Trying to connect a segment to the beginning of
another segment. This is not allowed. Please create an issue on github."""
parents[i] = -1

return parents


def _compute_pathlengths(
all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool
) -> List[np.ndarray]:
"""
Args:
coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius).
"""
branch_pathlengths = []
for b in all_branches:
coords_in_branch = coords[np.asarray(b) - 1]
if len(coords_in_branch) > 1:
# If the branch starts at a different neurite (e.g. the soma) then NEURON
# ignores the distance from that initial point. To reproduce, use the
# following SWC dummy file and read it in NEURON (and Jaxley):
# 1 1 0.00 0.0 0.0 6.0 -1
# 2 2 9.00 0.0 0.0 0.5 1
# 3 2 10.0 0.0 0.0 0.3 2
types = coords_in_branch[:, 0]
if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma:
coords_in_branch[0] = coords_in_branch[1]

# Compute distances between all traced points in a branch.
point_diffs = np.diff(coords_in_branch, axis=0)
dists = np.sqrt(
point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2
)
else:
# Jaxley uses length and radius for every compartment and assumes the
# surface area to be 2*pi*r*length. For branches consisting of a single
# traced point we assume for them to have area 4*pi*r*r. Therefore, we have
# to set length = 2*r.
radius = coords_in_branch[0, 4] # txyzr -> 4 is radius.
dists = np.asarray([2 * radius])
branch_pathlengths.append(dists)
return branch_pathlengths


def swc_to_jaxley(
fname: str,
max_branch_len: Optional[float] = None,
@@ -95,7 +340,7 @@ def swc_to_jaxley(


@deprecated_kwargs("0.6.0", ["nseg"])
def read_swc(
def read_swc_custom(
fname: str,
ncomp: Optional[int] = None,
nseg: Optional[int] = None,
@@ -179,3 +424,59 @@ def read_swc(
if len(indices) > 0:
cell.branch(indices).add_to_group(name)
return cell


@deprecated_kwargs("0.6.0", ["nseg"])
def read_swc(
fname: str,
ncomp: Optional[int] = None,
nseg: Optional[int] = None,
max_branch_len: Optional[float] = None,
min_radius: Optional[float] = None,
assign_groups: bool = True,
backend: str = "custom",
) -> Cell:
"""Reads SWC file into a `Cell`.
Jaxley assumes cylindrical compartments and therefore defines length and radius
for every compartment. The surface area is then 2*pi*r*length. For branches
consisting of a single traced point we assume for them to have area 4*pi*r*r.
Therefore, in these cases, we set lenght=2*r.
Args:
fname: Path to the swc file.
ncomp: The number of compartments per branch.
nseg: Deprecated. Use `ncomp` instead.
max_branch_len: If a branch is longer than this value it is split into two
branches.
min_radius: If the radius of a reconstruction is below this value it is clipped.
assign_groups: If True, then the identity of reconstructed points in the SWC
file will be used to generate groups `undefined`, `soma`, `axon`, `basal`,
`apical`, `custom`. See here:
http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
backend: The backend to use. Currently `custom` and `graph` are supported.
For context on these backends see `read_swc_custom` and `from_graph`.
Returns:
A `Cell` object."""

if backend == "custom":
return read_swc_custom(
fname,
ncomp=ncomp,
nseg=nseg,
max_branch_len=max_branch_len,
min_radius=min_radius,
assign_groups=assign_groups,
)
elif backend == "graph":
graph = swc_to_graph(fname)
return from_graph(
graph,
ncomp=ncomp,
max_branch_len=max_branch_len,
min_radius=min_radius,
assign_groups=assign_groups,
)
else:
raise ValueError(f"Unknown backend: {backend}. Use either `custom` or `graph`.")
15 changes: 9 additions & 6 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
@@ -2521,12 +2521,15 @@ def _set_inds_in_view(
incl_comps = pointer.nodes.loc[
self._nodes_in_view, "global_comp_index"
].unique()
pre = base_edges["pre_global_comp_index"].isin(incl_comps).to_numpy()
post = base_edges["post_global_comp_index"].isin(incl_comps).to_numpy()
possible_edges_in_view = base_edges.index.to_numpy()[(pre & post).flatten()]
self._edges_in_view = np.intersect1d(
possible_edges_in_view, self._edges_in_view
)
if not base_edges.empty:
pre = base_edges["pre_global_comp_index"].isin(incl_comps).to_numpy()
post = base_edges["post_global_comp_index"].isin(incl_comps).to_numpy()
possible_edges_in_view = base_edges.index.to_numpy()[
(pre & post).flatten()
]
self._edges_in_view = np.intersect1d(
possible_edges_in_view, self._edges_in_view
)
elif not has_node_inds and has_edge_inds:
base_nodes = self.base.nodes
self._edges_in_view = edges
257 changes: 2 additions & 255 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
@@ -1,270 +1,17 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from math import pi
from typing import Callable, Dict, List, Optional, Tuple, Union
from warnings import warn
from typing import Callable, Dict, List, Optional, Tuple

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit, vmap
from jax import vmap

from jaxley.utils.misc_utils import cumsum_leading_zero


def _split_into_branches_and_sort(
content: np.ndarray,
max_branch_len: Optional[float],
is_single_point_soma: bool,
sort: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
branches, types = _split_into_branches(content, is_single_point_soma)
if max_branch_len is not None:
branches, types = _split_long_branches(
branches,
types,
content,
max_branch_len,
is_single_point_soma=is_single_point_soma,
)

if sort:
first_val = np.asarray([b[0] for b in branches])
sorting = np.argsort(first_val, kind="mergesort")
sorted_branches = [branches[s] for s in sorting]
sorted_types = [types[s] for s in sorting]
else:
sorted_branches = branches
sorted_types = types
return sorted_branches, sorted_types


def _split_long_branches(
branches: np.ndarray,
types: np.ndarray,
content: np.ndarray,
max_branch_len: float,
is_single_point_soma: bool,
) -> Tuple[np.ndarray, np.ndarray]:
pathlengths = _compute_pathlengths(
branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
)
pathlengths = [np.sum(length_traced) for length_traced in pathlengths]
split_branches = []
split_types = []
for branch, type, length in zip(branches, types, pathlengths):
num_subbranches = 1
split_branch = [branch]
while length > max_branch_len:
num_subbranches += 1
split_branch = _split_branch_equally(branch, num_subbranches)
lengths_of_subbranches = _compute_pathlengths(
split_branch,
coords=content[:, 1:6],
is_single_point_soma=is_single_point_soma,
)
lengths_of_subbranches = [
np.sum(length_traced) for length_traced in lengths_of_subbranches
]
length = max(lengths_of_subbranches)
if num_subbranches > 10:
warn(
"""`num_subbranches > 10`, stopping to split. Most likely your
SWC reconstruction is not dense and some neighbouring traced
points are farther than `max_branch_len` apart."""
)
break
split_branches += split_branch
split_types += [type] * num_subbranches

return split_branches, split_types


def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]:
num_points_each = len(branch) // num_subbranches
branches = [branch[:num_points_each]]
for i in range(1, num_subbranches - 1):
branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each])
branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :])
return branches


def _split_into_branches(
content: np.ndarray, is_single_point_soma: bool
) -> Tuple[np.ndarray, np.ndarray]:
prev_ind = None
prev_type = None
n_branches = 0

# Branch inds will contain the row identifier at which a branch point occurs
# (i.e. the row of the parent of two branches).
branch_inds = []
for c in content:
current_ind = c[0]
current_parent = c[-1]
current_type = c[1]
if current_parent != prev_ind or current_type != prev_type:
branch_inds.append(int(current_parent))
n_branches += 1
prev_ind = current_ind
prev_type = current_type

all_branches = []
current_branch = []
all_types = []

# Loop over every line in the SWC file.
for c in content:
current_ind = c[0] # First col is row_identifier
current_parent = c[-1] # Last col is parent in SWC specification.
if current_parent == -1:
all_types.append(c[1])
else:
current_type = c[1]

if current_parent == -1 and is_single_point_soma and current_ind == 1:
all_branches.append([int(current_ind)])
all_types.append(int(current_type))

# Either append the current point to the branch, or add the branch to
# `all_branches`.
if current_parent in branch_inds[1:]:
if len(current_branch) > 1:
all_branches.append(current_branch)
all_types.append(current_type)
current_branch = [int(current_parent), int(current_ind)]
else:
current_branch.append(int(current_ind))

# Append the final branch (intermediate branches are already appended five lines
# above.)
all_branches.append(current_branch)
return all_branches, all_types


def _build_parents(all_branches: List[np.ndarray]) -> List[int]:
parents = [None] * len(all_branches)
all_last_inds = [b[-1] for b in all_branches]
for i, branch in enumerate(all_branches):
parent_ind = branch[0]
ind = np.where(np.asarray(all_last_inds) == parent_ind)[0]
if len(ind) > 0 and ind != i:
parents[i] = ind[0]
else:
assert (
parent_ind == 1
), """Trying to connect a segment to the beginning of
another segment. This is not allowed. Please create an issue on github."""
parents[i] = -1

return parents


def _radius_generating_fns(
all_branches: np.ndarray,
radiuses: np.ndarray,
each_length: np.ndarray,
parents: np.ndarray,
types: np.ndarray,
) -> List[Callable]:
"""For all branches in a cell, returns callable that return radius given loc."""
radius_fns = []
for i, branch in enumerate(all_branches):
rads_in_branch = radiuses[np.asarray(branch) - 1]
if parents[i] > -1 and types[i] != types[parents[i]]:
# We do not want to linearly interpolate between the radius of the previous
# branch if a new type of neurite is found (e.g. switch from soma to
# apical). From looking at the SWC from n140.swc I believe that this is
# also what NEURON does.
rads_in_branch[0] = rads_in_branch[1]
radius_fn = _radius_generating_fn(
radiuses=rads_in_branch, each_length=each_length[i]
)
# Beause SWC starts counting at 1, but numpy counts from 0.
# ind_of_branch_endpoint = np.asarray(b) - 1
radius_fns.append(radius_fn)
return radius_fns


def _padded_radius(loc: float, radiuses: np.ndarray) -> float:
return radiuses * np.ones_like(loc)


def _radius(loc: float, cutoffs: np.ndarray, radiuses: np.ndarray) -> float:
"""Function which returns the radius via linear interpolation.
Defined outside of `_radius_generating_fns` to allow for pickling of the resulting
Cell object."""
index = np.digitize(loc, cutoffs, right=False)
left_rad = radiuses[index - 1]
right_rad = radiuses[index]
left_loc = cutoffs[index - 1]
right_loc = cutoffs[index]
loc_within_bin = (loc - left_loc) / (right_loc - left_loc)
return left_rad + (right_rad - left_rad) * loc_within_bin


def _padded_radius_generating_fn(radiuses: np.ndarray) -> Callable:
return partial(_padded_radius, radiuses=radiuses)


def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable:
# Avoid division by 0 with the `summed_len` below.
each_length[each_length < 1e-8] = 1e-8
summed_len = np.sum(each_length)
cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len
cutoffs[0] -= 1e-8
cutoffs[-1] += 1e-8

# We have to linearly interpolate radiuses, therefore we need at least two radiuses.
# However, jaxley allows somata which consist of a single traced point (i.e.
# just one radius). Therefore, we just `tile` in order to generate an artificial
# endpoint and startpoint radius of the soma.
if len(radiuses) == 1:
radiuses = np.tile(radiuses, 2)

return partial(_radius, cutoffs=cutoffs, radiuses=radiuses)


def _compute_pathlengths(
all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool
) -> List[np.ndarray]:
"""
Args:
coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius).
"""
branch_pathlengths = []
for b in all_branches:
coords_in_branch = coords[np.asarray(b) - 1]
if len(coords_in_branch) > 1:
# If the branch starts at a different neurite (e.g. the soma) then NEURON
# ignores the distance from that initial point. To reproduce, use the
# following SWC dummy file and read it in NEURON (and Jaxley):
# 1 1 0.00 0.0 0.0 6.0 -1
# 2 2 9.00 0.0 0.0 0.5 1
# 3 2 10.0 0.0 0.0 0.3 2
types = coords_in_branch[:, 0]
if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma:
coords_in_branch[0] = coords_in_branch[1]

# Compute distances between all traced points in a branch.
point_diffs = np.diff(coords_in_branch, axis=0)
dists = np.sqrt(
point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2
)
else:
# Jaxley uses length and radius for every compartment and assumes the
# surface area to be 2*pi*r*length. For branches consisting of a single
# traced point we assume for them to have area 4*pi*r*r. Therefore, we have
# to set length = 2*r.
radius = coords_in_branch[0, 4] # txyzr -> 4 is radius.
dists = np.asarray([2 * radius])
branch_pathlengths.append(dists)
return branch_pathlengths


def build_radiuses_from_xyzr(
radius_fns: List[Callable],
branch_indices: List[int],
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ dependencies = [
"numpy",
"pandas>=2.2.0",
"tridiax",
"networkx",
]

[project.optional-dependencies]
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -153,6 +153,7 @@ def get_or_build_cell(
fname: Optional[str] = None,
ncomp: int = 1,
max_branch_len: float = 2_000.0,
swc_backend: str = "custom",
copy: bool = True,
force_init: bool = False,
) -> jx.Cell:
@@ -175,7 +176,11 @@ def get_or_build_cell(
fname = default_fname if fname is None else fname
if key := (fname, ncomp, max_branch_len) not in cells or force_init:
cells[key] = jx.read_swc(
fname, ncomp=ncomp, max_branch_len=max_branch_len, assign_groups=True
fname,
ncomp=ncomp,
max_branch_len=max_branch_len,
assign_groups=True,
backend=swc_backend,
)
return deepcopy(cells[key]) if copy and not force_init else cells[key]

158 changes: 158 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import numpy as np
import pandas as pd


def get_segment_xyzrL(section, comp_idx=None, loc=None, nseg=8):
assert (
comp_idx is not None or loc is not None
), "Either comp_idx or loc must be provided."
assert not (
comp_idx is not None and loc is not None
), "Only one of comp_idx or loc can be provided."

comp_len = 1 / nseg
loc = comp_len / 2 + comp_idx * comp_len if loc is None else loc

n3d = section.n3d()
x3d = np.array([section.x3d(i) for i in range(n3d)])
y3d = np.array([section.y3d(i) for i in range(n3d)])
z3d = np.array([section.z3d(i) for i in range(n3d)])
L = np.array([section.arc3d(i) for i in range(n3d)]) # Cumulative arc lengths
r3d = np.array([section.diam3d(i) / 2 for i in range(n3d)])
if loc is None:
return x3d, y3d, z3d, r3d
else:
total_length = L[-1]
target_length = loc * total_length

# Find segment containing target_length
for i in range(1, n3d):
if L[i] >= target_length:
break
else:
i = n3d - 1

# Interpolate between points i-1 and i
L0, L1 = L[i - 1], L[i]
t = (target_length - L0) / (L1 - L0)
x = x3d[i - 1] + t * (x3d[i] - x3d[i - 1])
y = y3d[i - 1] + t * (y3d[i] - y3d[i - 1])
z = z3d[i - 1] + t * (z3d[i] - z3d[i - 1])
r = r3d[i - 1] + t * (r3d[i] - r3d[i - 1])
return x, y, z, r, L[-1] / nseg


def jaxley2neuron_by_coords(jx_cell, neuron_secs, comp_idx=None, loc=None, nseg=8):
neuron_coords = {
i: np.vstack(get_segment_xyzrL(sec, comp_idx=comp_idx, loc=loc, nseg=nseg))[
:3
].T
for i, sec in enumerate(neuron_secs)
}
neuron_coords = np.vstack(
[np.hstack([k * np.ones((v.shape[0], 1)), v]) for k, v in neuron_coords.items()]
)
neuron_coords = pd.DataFrame(
neuron_coords, columns=["global_branch_index", "x", "y", "z"]
)
neuron_coords["global_branch_index"] = neuron_coords["global_branch_index"].astype(
int
)

neuron_loc_xyz = neuron_coords.groupby("global_branch_index").mean()
jaxley_loc_xyz = (
jx_cell.branch("all")
.loc(loc)
.nodes.set_index("global_branch_index")[["x", "y", "z"]]
)

jaxley2neuron_inds = {}
for i, xyz in enumerate(jaxley_loc_xyz.to_numpy()):
d = np.sqrt(((neuron_loc_xyz - xyz) ** 2)).sum(axis=1)
jaxley2neuron_inds[i] = d.argmin()
return jaxley2neuron_inds


def jaxley2neuron_by_group(
jx_cell,
neuron_secs,
comp_idx=None,
loc=None,
nseg=8,
num_apical=20,
num_tuft=20,
num_basal=10,
):
y_apical = (
jx_cell.apical.nodes.groupby("global_branch_index")
.mean()["y"]
.abs()
.sort_values()
)
trunk_inds = y_apical.index[:num_apical].tolist()
tuft_inds = y_apical.index[-num_tuft:].tolist()
basal_inds = (
jx_cell.basal.nodes["global_branch_index"].unique()[:num_basal].tolist()
)

jaxley2neuron = jaxley2neuron_by_coords(
jx_cell, neuron_secs, comp_idx=comp_idx, loc=loc, nseg=nseg
)

neuron_trunk_inds = [jaxley2neuron[i] for i in trunk_inds]
neuron_tuft_inds = [jaxley2neuron[i] for i in tuft_inds]
neuron_basal_inds = [jaxley2neuron[i] for i in basal_inds]

neuron_inds = {
"trunk": neuron_trunk_inds,
"tuft": neuron_tuft_inds,
"basal": neuron_basal_inds,
}
jaxley_inds = {"trunk": trunk_inds, "tuft": tuft_inds, "basal": basal_inds}
return neuron_inds, jaxley_inds


def match_stim_loc(jx_cell, neuron_sec, comp_idx=None, loc=None, nseg=8):
stim_coords = get_segment_xyzrL(neuron_sec, comp_idx=comp_idx, loc=loc, nseg=nseg)[
:3
]
stim_idx = (
((jx_cell.nodes[["x", "y", "z"]] - stim_coords) ** 2).sum(axis=1).argmin()
)
return stim_idx


def import_neuron_morph(fname, nseg=8):
from neuron import h

_ = h.load_file("stdlib.hoc")
_ = h.load_file("import3d.hoc")
nseg = 8

##################### NEURON ##################
for sec in h.allsec():
h.delete_section(sec=sec)

cell = h.Import3d_SWC_read()
cell.input(fname)
i3d = h.Import3d_GUI(cell, False)
i3d.instantiate(None)

for sec in h.allsec():
sec.nseg = nseg
return h, cell


def equal_both_nan_or_empty_df(a, b):
if a.empty and b.empty:
return True
a[a.isna()] = -1
b[b.isna()] = -1
if set(a.columns) != set(b.columns):
return False
else:
a = a[b.columns]
return (a == b).all()
178 changes: 177 additions & 1 deletion tests/jaxley_identical/test_swc.py
Original file line number Diff line number Diff line change
@@ -15,9 +15,10 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jaxley_mech.channels.l5pc import *

import jaxley as jx
from jaxley.channels import HH
from jaxley.channels import HH, K, Leak, Na
from jaxley.synapses import IonotropicSynapse


@@ -152,3 +153,178 @@ def test_swc_net(voltage_solver: str, SimpleMorphCell):
max_error = np.max(np.abs(voltages[:, ::20] - voltages_300724))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


# This test will be skipped for now, due to weird quirk of the swc file, which has two
# different radii for the same xyz coordinates (L10364: 10362 -> 10549). This is handled
# by the differently by the two swc reader backends (graph seems to be the correct one,
# compared to NEURON).
@pytest.mark.skip
@pytest.mark.slow
@pytest.mark.parametrize("swc_backend", ["custom", "graph"])
def test_swc_morph(swc_backend, SimpleMorphCell):
gt_apical = {}
gt_soma = {}
gt_axon = {}

gt_apical["apical_NaTs2T_gNaTs2T"] = 0.026145
gt_apical["apical_SKv3_1_gSKv3_1"] = 0.004226
gt_apical["apical_M_gM"] = 0.000143

gt_soma["somatic_NaTs2T_gNaTs2T"] = 0.983955
gt_soma["somatic_SKv3_1_gSKv3_1"] = 0.303472
gt_soma["somatic_SKE2_gSKE2"] = 0.008407
gt_soma["somatic_CaPump_gamma"] = 0.000609
gt_soma["somatic_CaPump_decay"] = 210.485291
gt_soma["somatic_CaHVA_gCaHVA"] = 0.000994
gt_soma["somatic_CaLVA_gCaLVA"] = 0.000333

gt_axon["axonal_NaTaT_gNaTaT"] = 3.137968
gt_axon["axonal_KPst_gKPst"] = 0.973538
gt_axon["axonal_KTst_gKTst"] = 0.089259
gt_axon["axonal_SKE2_gSKE2"] = 0.007104
gt_axon["axonal_SKv3_1_gSKv3_1"] = 1.021945
gt_axon["axonal_CaHVA_gCaHVA"] = 0.00099
gt_axon["axonal_CaLVA_gCaLVA"] = 0.008752
gt_axon["axonal_CaPump_gamma"] = 0.00291
gt_axon["axonal_CaPump_decay"] = 287.19873

dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "../swc_files", "bbp_with_axon.swc") # n120
cell = SimpleMorphCell(fname, ncomp=2, swc_backend=swc_backend)

# custom swc reader does not label the root branch that is added to the soma
# while the graph swc reader does. This is accounted for here.
cell.groups["soma"] = (
cell.groups["soma"][2:] if swc_backend == "graph" else cell.groups["soma"]
)
apical_inds = cell.groups["apical"]

########## APICAL ##########
cell.apical.set("capacitance", 2.0)
cell.apical.insert(NaTs2T().change_name("apical_NaTs2T"))
cell.apical.insert(SKv3_1().change_name("apical_SKv3_1"))
cell.apical.insert(M().change_name("apical_M"))
cell.apical.insert(H().change_name("apical_H"))

for c in apical_inds:
distance = cell.scope("global").comp(c).distance(cell.branch(1).loc(0.0))
cond = (-0.8696 + 2.087 * np.exp(distance * 0.0031)) * 8e-5
cell.scope("global").comp(c).set("apical_H_gH", cond)

########## SOMA ##########
cell.soma.insert(NaTs2T().change_name("somatic_NaTs2T"))
cell.soma.insert(SKv3_1().change_name("somatic_SKv3_1"))
cell.soma.insert(SKE2().change_name("somatic_SKE2"))
ca_dynamics = CaNernstReversal()
ca_dynamics.channel_constants["T"] = 307.15
cell.soma.insert(ca_dynamics)
cell.soma.insert(CaPump().change_name("somatic_CaPump"))
cell.soma.insert(CaHVA().change_name("somatic_CaHVA"))
cell.soma.insert(CaLVA().change_name("somatic_CaLVA"))
cell.soma.set("CaCon_i", 5e-05)
cell.soma.set("CaCon_e", 2.0)

########## BASAL ##########
cell.basal.insert(H().change_name("basal_H"))
cell.basal.set("basal_H_gH", 8e-5)

# ########## AXON ##########
cell.insert(CaNernstReversal())
cell.set("CaCon_i", 5e-05)
cell.set("CaCon_e", 2.0)
cell.axon.insert(NaTaT().change_name("axonal_NaTaT"))
cell.axon.insert(KTst().change_name("axonal_KTst"))
cell.axon.insert(CaPump().change_name("axonal_CaPump"))
cell.axon.insert(SKE2().change_name("axonal_SKE2"))
cell.axon.insert(CaHVA().change_name("axonal_CaHVA"))
cell.axon.insert(KPst().change_name("axonal_KPst"))
cell.axon.insert(SKv3_1().change_name("axonal_SKv3_1"))
cell.axon.insert(CaLVA().change_name("axonal_CaLVA"))

########## WHOLE CELL ##########
cell.insert(Leak())
cell.set("Leak_gLeak", 3e-05)
cell.set("Leak_eLeak", -75.0)

cell.set("axial_resistivity", 100.0)
cell.set("eNa", 50.0)
cell.set("eK", -85.0)
cell.set("v", -65.0)

for key in gt_apical.keys():
cell.apical.set(key, gt_apical[key])

for key in gt_soma.keys():
cell.soma.set(key, gt_soma[key])

for key in gt_axon.keys():
cell.axon.set(key, gt_axon[key])

dt = 0.025
t_max = 100.0
time_vec = np.arange(0, t_max + 2 * dt, dt)

cell.delete_stimuli()
cell.delete_recordings()

i_delay = 10.0
i_dur = 80.0
i_amp = 3.0
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
cell.branch(1).comp(0).stimulate(current)
cell.branch(1).comp(0).record()

cell.set("v", -65.0)
cell.init_states()

voltages = jx.integrate(cell)

voltages_250130 = jnp.asarray(
[
-65.0,
-66.22422623,
-67.23001452,
-68.06298803,
-68.75766951,
-33.91317711,
-55.24503749,
-46.11452291,
-42.18960646,
-51.12861864,
-43.65442616,
-40.62727385,
-49.56110473,
-43.24030949,
-36.71731271,
-48.7405489,
-42.98507829,
-34.64282586,
-48.24427898,
-42.6412365,
-34.70568206,
-47.90643598,
-42.15688181,
-36.17711814,
-47.65564274,
-41.52265914,
-38.1627371,
-47.44680473,
-40.70730741,
-40.15298353,
-47.25483146,
-39.63994798,
-41.96818737,
-47.06569105,
-38.17257448,
-43.50053648,
-46.87517934,
-65.40488865,
-69.96981343,
-72.24384111,
-73.46204372,
]
)
max_error = np.max(np.abs(voltages[:, ::100] - voltages_250130))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
10,619 changes: 10,619 additions & 0 deletions tests/swc_files/bbp_with_axon.swc

Large diffs are not rendered by default.

378 changes: 378 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,378 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import os
from copy import deepcopy

import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

import jax.numpy as jnp
import networkx as nx
import numpy as np
import pandas as pd
import pytest

import jaxley as jx
from jaxley import connect
from jaxley.channels import HH
from jaxley.channels.pospischil import K, Leak, Na
from jaxley.io.graph import (
add_missing_graph_attrs,
from_graph,
make_jaxley_compatible,
swc_to_graph,
to_graph,
trace_branches,
)
from jaxley.synapses import IonotropicSynapse, TestSynapse

# from jaxley.utils.misc_utils import recursive_compare
from tests.helpers import (
equal_both_nan_or_empty_df,
get_segment_xyzrL,
import_neuron_morph,
jaxley2neuron_by_group,
match_stim_loc,
)


# test exporting and re-importing of different modules
def test_graph_import_export_cycle(
SimpleComp, SimpleBranch, SimpleCell, SimpleNet, SimpleMorphCell
):
np.random.seed(0)
comp = SimpleComp()
branch = SimpleBranch(4)
cell = SimpleCell(5, 4)
morph_cell = SimpleMorphCell(ncomp=1)
net = SimpleNet(3, 5, 4)

# add synapses
connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse())
connect(net[0, 0, 1], net[1, 0, 1], IonotropicSynapse())
# connect(net[0, 0, 1], net[1, 0, 1], TestSynapse()) # makes test fail, see warning w. synapses = True

# add groups
net.cell(2).add_to_group("cell2")
net.cell(2).branch(1).add_to_group("cell2branch1")

# add ion channels
net.cell(0).insert(Na())
net.cell(0).insert(Leak())
net.cell(1).branch(1).insert(Na())
net.cell(0).insert(K())

# test consistency of exported and re-imported modules
for module in [comp, branch, cell, net, morph_cell]:
module.compute_xyz() # ensure x,y,z in nodes b4 exporting for later comparison
module_graph = to_graph(
module, channels=True, synapses=True
) # ensure to_graph works
re_module = from_graph(module_graph) # ensure prev exported graph can be read
re_module_graph = to_graph(
re_module, channels=True, synapses=True
) # ensure to_graph works for re-imported modules

# ensure original module and re-imported module are equal
assert np.all(equal_both_nan_or_empty_df(re_module.nodes, module.nodes))
assert np.all(equal_both_nan_or_empty_df(re_module.edges, module.edges))
assert np.all(
equal_both_nan_or_empty_df(re_module.branch_edges, module.branch_edges)
)

for k in module.groups:
assert k in re_module.groups
assert np.all(re_module.groups[k] == module.groups[k])

for re_xyzr, xyzr in zip(re_module.xyzr, module.xyzr):
re_xyzr[np.isnan(re_xyzr)] = -1
xyzr[np.isnan(xyzr)] = -1

assert np.all(re_xyzr == xyzr)

re_imported_mechs = re_module.channels + re_module.synapses
for re_mech, mech in zip(re_imported_mechs, module.channels + module.synapses):
assert np.all(re_mech.name == mech.name)

# ensure exported graph and re-exported graph are equal
node_df = pd.DataFrame(
[d for i, d in module_graph.nodes(data=True)], index=module_graph.nodes
).sort_index()
re_node_df = pd.DataFrame(
[d for i, d in re_module_graph.nodes(data=True)],
index=re_module_graph.nodes,
).sort_index()
assert np.all(equal_both_nan_or_empty_df(node_df, re_node_df))

edges = pd.DataFrame(
[
{
"pre_global_comp_index": i,
"post_global_comp_index": j,
**module_graph.edges[i, j],
}
for (i, j) in module_graph.edges
]
)
re_edges = pd.DataFrame(
[
{
"pre_global_comp_index": i,
"post_global_comp_index": j,
**re_module_graph.edges[i, j],
}
for (i, j) in re_module_graph.edges
]
)
assert np.all(equal_both_nan_or_empty_df(edges, re_edges))

# ignore "externals", "recordings", "trainable_params", "indices_set_by_trainables"
for k in ["ncomp", "xyzr"]:
assert module_graph.graph[k] == re_module_graph.graph[k]

# assume if module can be integrated, so can be comp, cell and branch
if isinstance(module, jx.Network):
# test integration of re-imported module
re_module.select(nodes=0).record(verbose=False)
jx.integrate(re_module, t_max=0.5)


@pytest.mark.parametrize(
"file", ["morph_single_point_soma.swc", "morph.swc", "bbp_with_axon.swc"]
)
def test_trace_branches(file):
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", file)
graph = swc_to_graph(fname)

# pre-processing
graph = add_missing_graph_attrs(graph)
graph = trace_branches(graph, None, ignore_swc_trace_errors=False)

edges = pd.DataFrame([{"u": u, "v": v, **d} for u, v, d in graph.edges(data=True)])
nx_branch_lens = edges.groupby("branch_index")["l"].sum().to_numpy()
nx_branch_lens = np.sort(nx_branch_lens)

# exclude artificial root branch
if np.isclose(nx_branch_lens[0], 1e-1):
nx_branch_lens = nx_branch_lens[1:]

h, _ = import_neuron_morph(fname)
neuron_branch_lens = np.sort([sec.L for sec in h.allsec()])

errors = np.abs(neuron_branch_lens - nx_branch_lens)
# one error is expected, see https://github.com/jaxleyverse/jaxley/issues/140
assert sum(errors > 1e-3) <= 1


@pytest.mark.parametrize(
"file", ["morph_single_point_soma.swc", "morph.swc", "bbp_with_axon.swc"]
)
def test_from_graph_vs_NEURON(file):
ncomp = 8
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", file)

graph = swc_to_graph(fname)
cell = from_graph(
graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False
)
cell.compute_compartment_centers()
h, neuron_cell = import_neuron_morph(fname, nseg=ncomp)

# remove root branch
jaxley_comps = cell.nodes[
~np.isclose(cell.nodes["length"], 0.1 / ncomp)
].reset_index(drop=True)

jx_branch_lens = (
jaxley_comps.groupby("global_branch_index")["length"].sum().to_numpy()
)

# match by branch lengths
neuron_xyzd = [np.array(s.psection()["morphology"]["pts3d"]) for s in h.allsec()]
neuron_branch_lens = np.array(
[
np.sqrt((np.diff(n[:, :3], axis=0) ** 2).sum(axis=1)).sum()
for n in neuron_xyzd
]
)
neuron_inds = np.argsort(neuron_branch_lens)
jx_inds = np.argsort(jx_branch_lens)

neuron_df = pd.DataFrame(columns=["neuron_idx", "x", "y", "z", "radius", "length"])
jx_df = pd.DataFrame(columns=["jx_idx", "x", "y", "z", "radius", "length"])
for k in range(len(neuron_inds)):
neuron_comp_k = np.array(
[
get_segment_xyzrL(list(h.allsec())[neuron_inds[k]], comp_idx=i)
for i in range(ncomp)
]
)
# make this a dataframe
neuron_comp_k = pd.DataFrame(
neuron_comp_k, columns=["x", "y", "z", "radius", "length"]
)
neuron_comp_k["idx"] = neuron_inds[k]
jx_comp_k = jaxley_comps[jaxley_comps["global_branch_index"] == jx_inds[k]][
["x", "y", "z", "radius", "length"]
]
jx_comp_k["idx"] = jx_inds[k]
neuron_df = pd.concat([neuron_df, neuron_comp_k], axis=0, ignore_index=True)
jx_df = pd.concat([jx_df, jx_comp_k], axis=0, ignore_index=True)

errors = neuron_df["neuron_idx"].to_frame()
errors["jx_idx"] = jx_df["jx_idx"]
errors[["x", "y", "z"]] = neuron_df[["x", "y", "z"]] - jx_df[["x", "y", "z"]]
errors["xyz"] = np.sqrt((errors[["x", "y", "z"]] ** 2).sum(axis=1))
errors["radius"] = neuron_df["radius"] - jx_df["radius"]
errors["length"] = neuron_df["length"] - jx_df["length"]

# one error is expected, see https://github.com/jaxleyverse/jaxley/issues/140
assert sum(errors.groupby("jx_idx")["xyz"].max() > 1e-3) <= 1
assert sum(errors.groupby("jx_idx")["radius"].max() > 1e-3) <= 1
assert sum(errors.groupby("jx_idx")["length"].max() > 1e-3) <= 1


def test_edges_only_to_jaxley():
# test if edge graph can pe imported into to jaxley
sets_of_edges = [
[(0, 1), (1, 2), (2, 3)],
[(0, 1), (1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)],
]
for edges in sets_of_edges:
edge_graph = nx.DiGraph(edges)
edge_module = from_graph(edge_graph)


@pytest.mark.slow
@pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"])
def test_swc2graph_voltages(file):
"""Check if voltages of SWC recording match.
To match the branch indices between NEURON and jaxley, we rely on comparing the
length of the branches.
It tests whether, on average over time and recordings, the voltage is off by less
than 1.5 mV.
"""
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", file) # n120

ncomp = 8

i_delay = 2.0
i_dur = 5.0
i_amp = 0.25
t_max = 20.0
dt = 0.025

##################### NEURON ##################
h, neuron_cell = import_neuron_morph(fname, nseg=ncomp)

####################### jaxley ##################
graph = swc_to_graph(fname)
jx_cell = from_graph(
graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False
)
jx_cell.compute_compartment_centers()
jx_cell.insert(HH())

branch_loc = 0.05
neuron_inds, jaxley_inds = jaxley2neuron_by_group(
jx_cell, h.allsec(), loc=branch_loc
)
trunk_inds, tuft_inds, basal_inds = [
jaxley_inds[key] for key in ["trunk", "tuft", "basal"]
]
neuron_trunk_inds, neuron_tuft_inds, neuron_basal_inds = [
neuron_inds[key] for key in ["trunk", "tuft", "basal"]
]

stim_loc = 0.1
stim_idx = match_stim_loc(jx_cell, h.soma[0], loc=stim_loc)

jx_cell.set("axial_resistivity", 1_000.0)
jx_cell.set("v", -62.0)
jx_cell.set("HH_m", 0.074901)
jx_cell.set("HH_h", 0.4889)
jx_cell.set("HH_n", 0.3644787)

jx_cell.select(stim_idx).stimulate(
jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
)
for i in trunk_inds + tuft_inds + basal_inds:
jx_cell.branch(i).loc(branch_loc).record()

voltages_jaxley = jx.integrate(jx_cell, delta_t=dt)

################### NEURON #################
stim = h.IClamp(h.soma[0](stim_loc))
stim.delay = i_delay
stim.dur = i_dur
stim.amp = i_amp

counter = 0
voltage_recs = {}

for r in neuron_trunk_inds:
for i, sec in enumerate(h.allsec()):
if i == r:
v = h.Vector()
v.record(sec(branch_loc)._ref_v)
voltage_recs[f"v{counter}"] = v
counter += 1

for r in neuron_tuft_inds:
for i, sec in enumerate(h.allsec()):
if i == r:
v = h.Vector()
v.record(sec(branch_loc)._ref_v)
voltage_recs[f"v{counter}"] = v
counter += 1

for r in neuron_basal_inds:
for i, sec in enumerate(h.allsec()):
if i == r:
v = h.Vector()
v.record(sec(branch_loc)._ref_v)
voltage_recs[f"v{counter}"] = v
counter += 1

for sec in h.allsec():
sec.insert("hh")
sec.Ra = 1_000.0

sec.gnabar_hh = 0.120 # S/cm2
sec.gkbar_hh = 0.036 # S/cm2
sec.gl_hh = 0.0003 # S/cm2
sec.ena = 50 # mV
sec.ek = -77.0 # mV
sec.el_hh = -54.3 # mV

h.dt = dt
tstop = t_max
v_init = -62.0

def initialize():
h.finitialize(v_init)
h.fcurrent()

def integrate():
while h.t < tstop:
h.fadvance()

initialize()
integrate()
voltages_neuron = np.asarray([voltage_recs[key] for key in voltage_recs])

####################### check ################
errors = np.mean(np.abs(voltages_jaxley - voltages_neuron), axis=1)

assert all(errors < 2.5), "voltages do not match."
5 changes: 2 additions & 3 deletions tests/test_swc.py
Original file line number Diff line number Diff line change
@@ -237,8 +237,7 @@ def integrate():
initialize()
integrate()
voltages_neuron = np.asarray([voltage_recs[key] for key in voltage_recs])
errors = np.mean(np.abs(voltages_jaxley - voltages_neuron), axis=1)

####################### check ################
assert np.mean(
np.abs(voltages_jaxley - voltages_neuron) < 1.5
), "voltages do not match."
assert all(errors < 2.5), "voltages do not match."

0 comments on commit 0753c2e

Please sign in to comment.