diff --git a/jaxley/build_branched_tridiag.py b/jaxley/build_branched_tridiag.py
deleted file mode 100644
index 74c6f27d..00000000
--- a/jaxley/build_branched_tridiag.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
-# licensed under the Apache License Version 2.0, see
-
-from math import pi
-
-import jax.numpy as jnp
-from jax import lax, vmap
-
-
-def define_all_tridiags(
- voltages: jnp.ndarray,
- voltage_terms: jnp.asarray,
- i_ext: jnp.ndarray,
- num_branches: int,
- coupling_conds_upper: float,
- coupling_conds_lower: float,
- summed_coupling_conds: float,
- dt: float,
-):
- """
- Set up tridiagonal system for each branch.
- """
- voltages = jnp.reshape(voltages, (num_branches, -1))
-
- voltage_terms = jnp.reshape(voltage_terms, (num_branches, -1))
- i_ext = jnp.reshape(i_ext, (num_branches, -1))
-
- lowers, diags, uppers, solves = vmap(
- _define_tridiag_for_branch, in_axes=(0, 0, 0, None, 0, 0, 0)
- )(
- voltages,
- voltage_terms,
- i_ext,
- dt,
- coupling_conds_upper,
- coupling_conds_lower,
- summed_coupling_conds,
- )
-
- return (lowers, diags, uppers, solves)
-
-
-def _define_tridiag_for_branch(
- voltages: jnp.ndarray,
- voltage_terms: jnp.ndarray,
- i_ext: jnp.ndarray,
- dt: float,
- coupling_conds_upper: float,
- coupling_conds_lower: float,
- summed_coupling_conds: float,
-):
- """
- Defines the tridiagonal system to solve for a single branch.
- """
-
- # Diagonal and solve.
- diags = 1.0 + dt * voltage_terms + dt * summed_coupling_conds
- solves = voltages + dt * i_ext
-
- # Subdiagonals.
- upper = jnp.asarray(-dt * coupling_conds_upper)
- lower = jnp.asarray(-dt * coupling_conds_lower)
- return lower, diags, upper, solves
diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py
index f9eb93ec..42da6653 100644
--- a/jaxley/modules/base.py
+++ b/jaxley/modules/base.py
@@ -7,7 +7,6 @@
from typing import Callable, Dict, List, Optional, Tuple, Union
import jax.numpy as jnp
-import networkx as nx
import numpy as np
import pandas as pd
from jax import jit, vmap
@@ -17,16 +16,16 @@
from jaxley.channels import Channel
from jaxley.solver_voltage import (
step_voltage_explicit,
- step_voltage_implicit_with_custom_spsolve,
step_voltage_implicit_with_jax_spsolve,
+ step_voltage_implicit_with_jaxley_spsolve,
)
from jaxley.synapses import Synapse
from jaxley.utils.cell_utils import (
_compute_index_of_child,
_compute_num_children,
+ compute_axial_conductances,
compute_levels,
convert_point_process_to_distributed,
- convert_to_csc,
interpolate_xyz,
loc_of_index,
query_channel_states_and_params,
@@ -35,6 +34,7 @@
from jaxley.utils.debug_solver import compute_morphology_indices
from jaxley.utils.misc_utils import childview, concat_and_ignore_empty
from jaxley.utils.plot_utils import plot_morph
+from jaxley.utils.solver_utils import convert_to_csc
class Module(ABC):
@@ -258,7 +258,7 @@ def _show(
def init_morph(self):
"""Initialize the morphology such that it can be processed by the solvers."""
- self._init_morph_custom_spsolve()
+ self._init_morph_jaxley_spsolve()
self._init_morph_jax_spsolve()
self.initialized_morph = True
@@ -268,36 +268,13 @@ def _init_morph_jax_spsolve(self):
raise NotImplementedError
@abstractmethod
- def _init_morph_custom_spsolve(self):
+ def _init_morph_jaxley_spsolve(self):
"""Initialize the morphology for the custom Jaxley solver."""
raise NotImplementedError
- def init_conds(self, params: Dict, voltage_solver: str):
+ def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]):
"""Given radius, length, r_a, compute the axial coupling conductances."""
- if voltage_solver.startswith("jaxley"):
- return self._init_conds_custom_spsolve(params)
- else:
- return self._init_conds_jax_spsolve(params)
-
- @abstractmethod
- def _init_conds_jax_spsolve(self, params: Dict):
- """Initialize coupling conductances.
-
- Args:
- params: Conductances and morphology parameters, not yet including
- coupling conductances.
- """
- raise NotImplementedError
-
- @abstractmethod
- def _init_conds_custom_spsolve(self, params: Dict):
- """Initialize coupling conductances.
-
- Args:
- params: Conductances and morphology parameters, not yet including
- coupling conductances.
- """
- raise NotImplementedError
+ return compute_axial_conductances(self._comp_edges, params)
def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"):
"""Adds channel nodes from constituents to `self.channel_nodes`."""
@@ -525,9 +502,9 @@ def get_all_parameters(
) -> Dict[str, jnp.ndarray]:
"""Return all parameters (and coupling conductances) needed to simulate.
- Runs `init_conds()` and return every parameter that is needed to solve the ODE.
- This includes conductances, radiuses, lengths, axial_resistivities, but also
- coupling conductances.
+ Runs `_compute_axial_conductances()` and return every parameter that is needed
+ to solve the ODE. This includes conductances, radiuses, lengths,
+ axial_resistivities, but also coupling conductances.
This is done by first obtaining the current value of every parameter (not only
the trainable ones) and then replacing the trainable ones with the value
@@ -576,11 +553,8 @@ def get_all_parameters(
# `.set()` to work. This is done with `[:, None]`.
params[key] = params[key].at[inds].set(set_param[:, None])
- # Compute conductance params and append them.
- cond_params = self.init_conds(params=params, voltage_solver=voltage_solver)
- for key in cond_params:
- params[key] = cond_params[key]
-
+ # Compute conductance params and add them to the params dictionary.
+ params["axial_conductances"] = self._compute_axial_conductances(params=params)
return params
def get_states_from_nodes_and_edges(self):
@@ -963,19 +937,26 @@ def step(
# Voltage steps.
cm = params["capacitance"] # Abbreviation.
+ # Arguments used by all solvers.
+ solver_kwargs = {
+ "voltages": voltages,
+ "voltage_terms": (v_terms + syn_v_terms) / cm,
+ "constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
+ "axial_conductances": params["axial_conductances"],
+ "internal_node_inds": self._internal_node_inds,
+ }
+
+ # Add solver specific arguments.
if voltage_solver == "jax.sparse":
- solver_kwargs = {
- "voltages": voltages,
- "voltage_terms": (v_terms + syn_v_terms) / cm,
- "constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
- "axial_conductances": params["axial_conductances"],
- "data_inds": self._data_inds,
- "indices": self._indices_jax_spsolve,
- "indptr": self._indptr_jax_spsolve,
- "sinks": np.asarray(self._comp_edges["sink"].to_list()),
- "n_nodes": self._n_nodes,
- "internal_node_inds": self._internal_node_inds,
- }
+ solver_kwargs.update(
+ {
+ "sinks": np.asarray(self._comp_edges["sink"].to_list()),
+ "data_inds": self._data_inds,
+ "indices": self._indices_jax_spsolve,
+ "indptr": self._indptr_jax_spsolve,
+ "n_nodes": self._n_nodes,
+ }
+ )
# Only for `bwd_euler` and `cranck-nicolson`.
step_voltage_implicit = step_voltage_implicit_with_jax_spsolve
else:
@@ -985,29 +966,27 @@ def step(
# Currently, the forward Euler solver also uses this format. However,
# this is only for historical reasons and we are planning to change this in
# the future.
- solver_kwargs = {
- "voltages": voltages,
- "voltage_terms": (v_terms + syn_v_terms) / cm,
- "constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
- "coupling_conds_upper": params["branch_uppers"],
- "coupling_conds_lower": params["branch_lowers"],
- "summed_coupling_conds": params["branch_diags"],
- "branchpoint_conds_children": params["branchpoint_conds_children"],
- "branchpoint_conds_parents": params["branchpoint_conds_parents"],
- "branchpoint_weights_children": params["branchpoint_weights_children"],
- "branchpoint_weights_parents": params["branchpoint_weights_parents"],
- "par_inds": self.par_inds,
- "child_inds": self.child_inds,
- "nbranches": self.total_nbranches,
- "solver": voltage_solver,
- "children_in_level": self.children_in_level,
- "parents_in_level": self.parents_in_level,
- "root_inds": self.root_inds,
- "branchpoint_group_inds": self.branchpoint_group_inds,
- "debug_states": self.debug_states,
- }
+ solver_kwargs.update(
+ {
+ "sinks": np.asarray(self._comp_edges["sink"].to_list()),
+ "sources": np.asarray(self._comp_edges["source"].to_list()),
+ "types": np.asarray(self._comp_edges["type"].to_list()),
+ "masked_node_inds": self._remapped_node_indices,
+ "nseg_per_branch": self.nseg_per_branch,
+ "nseg": self.nseg,
+ "par_inds": self.par_inds,
+ "child_inds": self.child_inds,
+ "nbranches": self.total_nbranches,
+ "solver": voltage_solver,
+ "children_in_level": self.children_in_level,
+ "parents_in_level": self.parents_in_level,
+ "root_inds": self.root_inds,
+ "branchpoint_group_inds": self.branchpoint_group_inds,
+ "debug_states": self.debug_states,
+ }
+ )
# Only for `bwd_euler` and `cranck-nicolson`.
- step_voltage_implicit = step_voltage_implicit_with_custom_spsolve
+ step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve
if solver == "bwd_euler":
u["v"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)
diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py
index 6d5deeef..d888eb8a 100644
--- a/jaxley/modules/branch.py
+++ b/jaxley/modules/branch.py
@@ -3,20 +3,16 @@
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union
+from warnings import warn
import jax.numpy as jnp
import numpy as np
import pandas as pd
-from jax import vmap
from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.compartment import Compartment, CompartmentView
-from jaxley.utils.cell_utils import (
- comp_edges_to_indices,
- compute_axial_conductances,
- compute_children_and_parents,
- compute_coupling_cond,
-)
+from jaxley.utils.cell_utils import compute_children_and_parents
+from jaxley.utils.solver_utils import comp_edges_to_indices
class Branch(Module):
@@ -88,7 +84,7 @@ def __init__(
self.par_inds, self.child_inds, self.child_belongs_to_branchpoint = (
compute_children_and_parents(self.branch_edges)
)
- self.root_inds = jnp.asarray([0])
+ self._internal_node_inds = jnp.arange(self.nseg)
self.initialize()
self.init_syns()
@@ -118,20 +114,81 @@ def __getattr__(self, key: str):
else:
raise KeyError(f"Key {key} not recognized.")
- def _init_morph_custom_spsolve(self):
+ def set_ncomp(self, ncomp: int):
+ """Set the number of compartments with which the branch is discretized."""
+ radius_generating_functions = lambda x: 0.5
+ within_branch_radiuses = self.nodes["radius"]
+
+ if ~np.all_equal(within_branch_radiuses) and radius_generating_functions is None:
+ warn(
+ f"You previously modified the radius of individual compartments, but now"
+ f"you are modifying the number of compartments in this branch. We are"
+ f"resetting every radius in this branch to 1um. To avoid this, first"
+ f"set the number of compartments in every branch and then modify their radius."
+ )
+ within_branch_radiuses = 1.0 * np.ones_like(within_branch_radiuses)
+
+ if ~np.all_equal(within_branch_lengths):
+ warn(
+ f"You previously modified the length of individual compartments, but now"
+ f"you are modifying the number of compartments in this branch. We are"
+ f"now assuming that the lenght of every compartment in this branch is equal,"
+ f"such that the branch has the same length as with the old number of compartments."
+ f"To avoid this, first set the number of compartments in every branch and then modify their radius."
+ )
+
+ # Compute new compartment lengths.
+ comp_lengths = np.sum(compartment_lengths) / ncomp
+
+ # Compute new compartment radiuses.
+ if radius_generating_functions is not None:
+ comp_radiuses = radius_generating_functions(np.linspace(0, 1, ncomps))
+ else:
+ comp_radiuses = within_branch_radiuses
+
+ # Add new row as the average of all rows.
+ df = self.nodes
+ average_row = df.mean(skipna=False)
+ average_row = average_row.to_frame().T
+ df = pd.concat([df, average_row], axis="rows")
+
+ # Set the correct datatype after having performed an average which cast
+ # everything to float.
+ integer_cols = ["comp_index", "branch_index", "cell_index"]
+ df[integer_cols] = df[integer_cols].astype(int)
+
+ # Update the comp_index, branch_index, cell_index.
+ # TODO.
+
+ # Special treatment for channels. Channels will only be added to the new nseg
+ # if **all** other segments in the branch also had that channel.
+ channel_cols = ["HH"]
+ df[channel_cols] = np.floor(df[channel_cols]).astype(bool)
+
+ # Special treatment for the lengths.
+ df["length"] = comp_lengths
+
+ # Special treatment for the radiuses.
+ df["radius"] = comp_radiuses
+
+
+ def _init_morph_jaxley_spsolve(self):
self.branchpoint_group_inds = np.asarray([]).astype(int)
+ self.root_inds = jnp.asarray([0])
+ self._remapped_node_indices = self._internal_node_inds
self.children_in_level = []
self.parents_in_level = []
def _init_morph_jax_spsolve(self):
"""Initialize morphology for the jax sparse voltage solver.
- Explanation of `type`:
- `type == 0`: compartment-to-compartment (within branch)
- `type == 1`: branchpoint-to-compartment
- `type == 2`: compartment-to-branchpoint
+ Explanation of `self._comp_eges['type']`:
+ `type == 0`: compartment <--> compartment (within branch)
+ `type == 1`: branchpoint --> parent-compartment
+ `type == 2`: branchpoint --> child-compartment
+ `type == 3`: parent-compartment --> branchpoint
+ `type == 4`: child-compartment --> branchpoint
"""
- self._internal_node_inds = jnp.arange(self.nseg)
self._comp_edges = pd.DataFrame().from_dict(
{
"source": list(range(self.nseg - 1)) + list(range(1, self.nseg)),
@@ -145,65 +202,6 @@ def _init_morph_jax_spsolve(self):
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr
- def _init_conds_custom_spsolve(self, params: Dict) -> Dict[str, jnp.ndarray]:
- conds = self.init_branch_conds_custom_spsolve(
- params["axial_resistivity"], params["radius"], params["length"], self.nseg
- )
- cond_params = {
- "branchpoint_conds_children": jnp.asarray([]),
- "branchpoint_conds_parents": jnp.asarray([]),
- "branchpoint_weights_children": jnp.asarray([]),
- "branchpoint_weights_parents": jnp.asarray([]),
- }
- cond_params["branch_lowers"] = conds[0]
- cond_params["branch_uppers"] = conds[1]
- cond_params["branch_diags"] = conds[2]
-
- return cond_params
-
- def _init_conds_jax_spsolve(self, params: Dict) -> Dict[str, jnp.ndarray]:
- conds = compute_axial_conductances(self._comp_edges, params)
- return {"axial_conductances": conds}
-
- @staticmethod
- def init_branch_conds_custom_spsolve(
- axial_resistivity: jnp.ndarray,
- radiuses: jnp.ndarray,
- lengths: jnp.ndarray,
- nseg: int,
- ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
- """Given an axial resisitivity, set the coupling conductances.
-
- Args:
- axial_resistivity: Axial resistivity of each compartment.
- radiuses: Radius of each compartment.
- lengths: Length of each compartment.
- nseg: Number of compartments in the branch.
-
- Returns:
- Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.
- """
-
- # Compute coupling conductance for segments within a branch.
- # `radius`: um
- # `r_a`: ohm cm
- # `length_single_compartment`: um
- # `coupling_conds`: S * um / cm / um^2 = S / cm / um
- r1 = radiuses[:-1]
- r2 = radiuses[1:]
- r_a1 = axial_resistivity[:-1]
- r_a2 = axial_resistivity[1:]
- l1 = lengths[:-1]
- l2 = lengths[1:]
- coupling_conds_bwd = compute_coupling_cond(r1, r2, r_a1, r_a2, l1, l2)
- coupling_conds_fwd = compute_coupling_cond(r2, r1, r_a2, r_a1, l2, l1)
-
- # Compute the summed coupling conductances of each compartment.
- summed_coupling_conds = jnp.zeros((nseg))
- summed_coupling_conds = summed_coupling_conds.at[1:].add(coupling_conds_fwd)
- summed_coupling_conds = summed_coupling_conds.at[:-1].add(coupling_conds_bwd)
- return coupling_conds_fwd, coupling_conds_bwd, summed_coupling_conds
-
def __len__(self) -> int:
return self.nseg
diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py
index 8578e2bb..d8aa513b 100644
--- a/jaxley/modules/cell.py
+++ b/jaxley/modules/cell.py
@@ -1,35 +1,26 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see
-import time
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union
import jax.numpy as jnp
import numpy as np
import pandas as pd
-from jax import vmap
-from jax.lax import ScatterDimensionNumbers, scatter_add
from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.branch import Branch, BranchView, Compartment
from jaxley.synapses import Synapse
from jaxley.utils.cell_utils import (
build_branchpoint_group_inds,
- comp_edges_to_indices,
- compute_axial_conductances,
compute_children_and_parents,
compute_children_in_level,
compute_children_indices,
- compute_coupling_cond,
- compute_coupling_cond_branchpoint,
- compute_impact_on_node,
compute_levels,
compute_morphology_indices_in_levels,
compute_parents_in_level,
- loc_of_index,
- remap_to_consecutive,
)
+from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked
from jaxley.utils.swc import swc_to_jaxley
@@ -132,6 +123,7 @@ def __init__(
self.par_inds, self.child_inds, self.child_belongs_to_branchpoint = (
compute_children_and_parents(self.branch_edges)
)
+ self._internal_node_inds = np.arange(self.cumsum_nseg[-1])
self.initialize()
self.init_syns()
@@ -157,11 +149,13 @@ def __getattr__(self, key: str):
else:
raise KeyError(f"Key {key} not recognized.")
- def _init_morph_custom_spsolve(self):
+ def _init_morph_jaxley_spsolve(self):
"""Initialize morphology for the custom sparse solver.
Running this function is only required for custom Jaxley solvers, i.e., for
- `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`.
+ `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`. However, because at
+ `.__init__()` (when the function is run), we do not yet know which solver the
+ user will use. Therefore, we always run this function at `.__init__()`.
"""
children_and_parents = compute_morphology_indices_in_levels(
len(self.par_inds),
@@ -185,18 +179,28 @@ def _init_morph_custom_spsolve(self):
)
self.root_inds = jnp.asarray([0])
+ # Generate mapping to dealing with the masking which allows using the custom
+ # sparse solver to deal with different nseg per branch.
+ self._remapped_node_indices = remap_index_to_masked(
+ self._internal_node_inds,
+ self.nodes,
+ self.nseg,
+ self.nseg_per_branch,
+ )
+
def _init_morph_jax_spsolve(self):
"""For morphology indexing with the `jax.sparse` voltage volver.
- Explanation of `type`:
- `type == 0`: compartment-to-compartment (within branch)
- `type == 1`: branchpoint-to-compartment
- `type == 2`: compartment-to-branchpoint
+ Explanation of `self._comp_eges['type']`:
+ `type == 0`: compartment <--> compartment (within branch)
+ `type == 1`: branchpoint --> parent-compartment
+ `type == 2`: branchpoint --> child-compartment
+ `type == 3`: parent-compartment --> branchpoint
+ `type == 4`: child-compartment --> branchpoint
Running this function is only required for generic sparse solvers, i.e., for
`voltage_solver='jax.sparse'`.
"""
- self._internal_node_inds = np.arange(self.cumsum_nseg[-1])
# Edges between compartments within the branches.
# `[offset, offset, 0]` because we want to offset `source` and `sink`, but
@@ -212,13 +216,6 @@ def _init_morph_jax_spsolve(self):
del self.branch_list
# Edges from branchpoints to compartments.
- branchpoint_to_child_edges = pd.DataFrame().from_dict(
- {
- "source": self.child_belongs_to_branchpoint + self.cumsum_nseg[-1],
- "sink": self.cumsum_nseg[self.child_inds],
- "type": 1,
- }
- )
branchpoint_to_parent_edges = pd.DataFrame().from_dict(
{
"source": np.arange(len(self.par_inds)) + self.cumsum_nseg[-1],
@@ -226,6 +223,13 @@ def _init_morph_jax_spsolve(self):
"type": 1,
}
)
+ branchpoint_to_child_edges = pd.DataFrame().from_dict(
+ {
+ "source": self.child_belongs_to_branchpoint + self.cumsum_nseg[-1],
+ "sink": self.cumsum_nseg[self.child_inds],
+ "type": 2,
+ }
+ )
self._comp_edges = pd.concat(
[
self._comp_edges,
@@ -236,15 +240,14 @@ def _init_morph_jax_spsolve(self):
)
# Edges from compartments to branchpoints.
- child_to_branchpoint_edges = branchpoint_to_child_edges.rename(
+ parent_to_branchpoint_edges = branchpoint_to_parent_edges.rename(
columns={"sink": "source", "source": "sink"}
)
- child_to_branchpoint_edges["type"] = 2
-
- parent_to_branchpoint_edges = branchpoint_to_parent_edges.rename(
+ parent_to_branchpoint_edges["type"] = 3
+ child_to_branchpoint_edges = branchpoint_to_child_edges.rename(
columns={"sink": "source", "source": "sink"}
)
- parent_to_branchpoint_edges["type"] = 2
+ child_to_branchpoint_edges["type"] = 4
self._comp_edges = pd.concat(
[
@@ -261,79 +264,8 @@ def _init_morph_jax_spsolve(self):
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr
- def _init_conds_custom_spsolve(self, params: Dict) -> Dict[str, jnp.ndarray]:
- """Given an axial resisitivity, set the coupling conductances."""
- nbranches = self.total_nbranches
- nseg = self.nseg
-
- axial_resistivity = jnp.reshape(params["axial_resistivity"], (nbranches, nseg))
- radiuses = jnp.reshape(params["radius"], (nbranches, nseg))
- lengths = jnp.reshape(params["length"], (nbranches, nseg))
-
- conds = vmap(Branch.init_branch_conds_custom_spsolve, in_axes=(0, 0, 0, None))(
- axial_resistivity, radiuses, lengths, self.nseg
- )
- coupling_conds_fwd = conds[0]
- coupling_conds_bwd = conds[1]
- summed_coupling_conds = conds[2]
-
- # The conductance from the children to the branch point.
- branchpoint_conds_children = vmap(
- compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)
- )(
- radiuses[self.child_inds, 0],
- axial_resistivity[self.child_inds, 0],
- lengths[self.child_inds, 0],
- )
- # The conductance from the parents to the branch point.
- branchpoint_conds_parents = vmap(
- compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)
- )(
- radiuses[self.par_inds, -1],
- axial_resistivity[self.par_inds, -1],
- lengths[self.par_inds, -1],
- )
-
- # Weights with which the compartments influence their nearby node.
- # The impact of the children on the branch point.
- branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(
- radiuses[self.child_inds, 0],
- axial_resistivity[self.child_inds, 0],
- lengths[self.child_inds, 0],
- )
- # The impact of parents on the branch point.
- branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(
- radiuses[self.par_inds, -1],
- axial_resistivity[self.par_inds, -1],
- lengths[self.par_inds, -1],
- )
-
- summed_coupling_conds = self.update_summed_coupling_conds_custom_spsolve(
- summed_coupling_conds,
- self.child_inds,
- self.par_inds,
- branchpoint_conds_children,
- branchpoint_conds_parents,
- )
-
- cond_params = {
- "branch_uppers": coupling_conds_bwd,
- "branch_lowers": coupling_conds_fwd,
- "branch_diags": summed_coupling_conds,
- "branchpoint_conds_children": branchpoint_conds_children,
- "branchpoint_conds_parents": branchpoint_conds_parents,
- "branchpoint_weights_children": branchpoint_weights_children,
- "branchpoint_weights_parents": branchpoint_weights_parents,
- }
- return cond_params
-
- def _init_conds_jax_spsolve(self, params: Dict) -> Dict[str, jnp.ndarray]:
- """Given length, radius, and r_a, set the coupling conductances."""
- conds = compute_axial_conductances(self._comp_edges, params)
- return {"axial_conductances": conds}
-
@staticmethod
- def update_summed_coupling_conds_custom_spsolve(
+ def update_summed_coupling_conds_jaxley_spsolve(
summed_conds,
child_inds,
par_inds,
diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py
index 2f8c320c..9a3571c3 100644
--- a/jaxley/modules/compartment.py
+++ b/jaxley/modules/compartment.py
@@ -10,12 +10,12 @@
from jaxley.modules.base import Module, View
from jaxley.utils.cell_utils import (
- comp_edges_to_indices,
compute_children_and_parents,
index_of_loc,
interpolate_xyz,
loc_of_index,
)
+from jaxley.utils.solver_utils import comp_edges_to_indices
class Compartment(Module):
@@ -57,7 +57,7 @@ def __init__(self):
self.par_inds, self.child_inds, self.child_belongs_to_branchpoint = (
compute_children_and_parents(self.branch_edges)
)
- self.root_inds = jnp.asarray([0])
+ self._internal_node_inds = jnp.asarray([0])
# Initialize the module.
self.initialize()
@@ -66,20 +66,23 @@ def __init__(self):
# Coordinates.
self.xyzr = [float("NaN") * np.zeros((2, 4))]
- def _init_morph_custom_spsolve(self):
+ def _init_morph_jaxley_spsolve(self):
self.branchpoint_group_inds = np.asarray([]).astype(int)
+ self.root_inds = jnp.asarray([0])
+ self._remapped_node_indices = self._internal_node_inds
self.children_in_level = []
self.parents_in_level = []
def _init_morph_jax_spsolve(self):
"""Initialize morphology for the jax sparse voltage solver.
- Explanation of `type`:
- `type == 0`: compartment-to-compartment (within branch)
- `type == 1`: compartment-to-branchpoint
- `type == 2`: branchpoint-to-compartment
+ Explanation of `self._comp_eges['type']`:
+ `type == 0`: compartment <--> compartment (within branch)
+ `type == 1`: branchpoint --> parent-compartment
+ `type == 2`: branchpoint --> child-compartment
+ `type == 3`: parent-compartment --> branchpoint
+ `type == 4`: child-compartment --> branchpoint
"""
- self._internal_node_inds = jnp.asarray([0])
self._comp_edges = pd.DataFrame().from_dict(
{"source": [], "sink": [], "type": []}
)
@@ -89,18 +92,10 @@ def _init_morph_jax_spsolve(self):
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr
- def _init_conds_custom_spsolve(self, params):
- return {
- "branchpoint_conds_children": jnp.asarray([]),
- "branchpoint_conds_parents": jnp.asarray([]),
- "branchpoint_weights_children": jnp.asarray([]),
- "branchpoint_weights_parents": jnp.asarray([]),
- "branch_uppers": jnp.asarray([]),
- "branch_lowers": jnp.asarray([]),
- "branch_diags": jnp.asarray([0.0]),
- }
-
- def _init_conds_jax_spsolve(self, params):
+ def init_conds(self, params: Dict[str, jnp.ndarray]):
+ """Override `Base.init_axial_conds()`.
+
+ This is because compartments do not have any axial conductances."""
return {"axial_conductances": jnp.asarray([])}
diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py
index d4b0eaca..707777ef 100644
--- a/jaxley/modules/network.py
+++ b/jaxley/modules/network.py
@@ -17,16 +17,11 @@
from jaxley.modules.cell import Cell, CellView
from jaxley.utils.cell_utils import (
build_branchpoint_group_inds,
- comp_edges_to_indices,
- compute_axial_conductances,
compute_children_and_parents,
- compute_coupling_cond,
- compute_coupling_cond_branchpoint,
- compute_impact_on_node,
convert_point_process_to_distributed,
merge_cells,
- remap_to_consecutive,
)
+from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked
from jaxley.utils.syn_utils import gather_synapes
@@ -98,6 +93,7 @@ def __init__(
self.par_inds, self.child_inds, self.child_belongs_to_branchpoint = (
compute_children_and_parents(self.branch_edges)
)
+ self._internal_node_inds = np.arange(self.cumsum_nseg[-1])
# `nbranchpoints` in each cell == cell.par_inds (because `par_inds` are unique).
nbranchpoints = jnp.asarray([len(cell.par_inds) for cell in self.cells])
@@ -135,7 +131,7 @@ def __getattr__(self, key: str):
else:
raise KeyError(f"Key {key} not recognized.")
- def _init_morph_custom_spsolve(self):
+ def _init_morph_jaxley_spsolve(self):
self.branchpoint_group_inds = build_branchpoint_group_inds(
len(self.par_inds),
self.child_belongs_to_branchpoint,
@@ -155,19 +151,34 @@ def _init_morph_custom_spsolve(self):
)
self.root_inds = self.cumsum_nbranches[:-1]
+ # Generate mapping to dealing with the masking which allows using the custom
+ # sparse solver to deal with different nseg per branch.
+ self._remapped_node_indices = remap_index_to_masked(
+ self._internal_node_inds,
+ self.nodes,
+ self.nseg,
+ self.nseg_per_branch,
+ )
+
def _init_morph_jax_spsolve(self):
"""Initialize the morphology for networks.
- The reason that this is a bit involved is that Jaxley considers branchpoint
- nodes to be at the very end of __all__ nodes (i.e. the branchpoints of the
- first cell are even after the compartments of the second cell. The reason for
- this is that, otherwise, `cumsum_nseg` becomes tricky).
+ The reason that this function is a bit involved for a `Network` is that Jaxley
+ considers branchpoint nodes to be at the very end of __all__ nodes (i.e. the
+ branchpoints of the first cell are even after the compartments of the second
+ cell. The reason for this is that, otherwise, `cumsum_nseg` becomes tricky).
To achieve this, we first loop over all compartments and append them, and then
loop over all branchpoints and append those. The code for building the indices
from the `comp_edges` is identical to `jx.Cell`.
+
+ Explanation of `self._comp_eges['type']`:
+ `type == 0`: compartment <--> compartment (within branch)
+ `type == 1`: branchpoint --> parent-compartment
+ `type == 2`: branchpoint --> child-compartment
+ `type == 3`: parent-compartment --> branchpoint
+ `type == 4`: child-compartment --> branchpoint
"""
- self._internal_node_inds = np.arange(self.cumsum_nseg[-1])
self._cumsum_nseg_per_cell = jnp.concatenate(
[
jnp.asarray([0]),
@@ -191,7 +202,7 @@ def _init_morph_jax_spsolve(self):
self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, self.cells
):
offset_within_cell = cell.cumsum_nseg[-1]
- condition = cell._comp_edges["type"].to_numpy() == 1
+ condition = np.isin(cell._comp_edges["type"].to_numpy(), [1, 2])
rows = cell._comp_edges[condition]
self._comp_edges = pd.concat(
[
@@ -211,7 +222,7 @@ def _init_morph_jax_spsolve(self):
self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, self.cells
):
offset_within_cell = cell.cumsum_nseg[-1]
- condition = cell._comp_edges["type"].to_numpy() == 2
+ condition = np.isin(cell._comp_edges["type"].to_numpy(), [3, 4])
rows = cell._comp_edges[condition]
self._comp_edges = pd.concat(
[
@@ -236,78 +247,6 @@ def _init_morph_jax_spsolve(self):
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr
- def _init_conds_custom_spsolve(self, params: Dict) -> Dict[str, jnp.ndarray]:
- """Given an axial resisitivity, set the coupling conductances."""
- nbranches = self.total_nbranches
- nseg = self.nseg
- parents = self.comb_parents
-
- axial_resistivity = jnp.reshape(params["axial_resistivity"], (nbranches, nseg))
- radiuses = jnp.reshape(params["radius"], (nbranches, nseg))
- lengths = jnp.reshape(params["length"], (nbranches, nseg))
-
- conds = vmap(Branch.init_branch_conds_custom_spsolve, in_axes=(0, 0, 0, None))(
- axial_resistivity, radiuses, lengths, self.nseg
- )
- coupling_conds_fwd = conds[0]
- coupling_conds_bwd = conds[1]
- summed_coupling_conds = conds[2]
-
- # The conductance from the children to the branch point.
- branchpoint_conds_children = vmap(
- compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)
- )(
- radiuses[self.child_inds, 0],
- axial_resistivity[self.child_inds, 0],
- lengths[self.child_inds, 0],
- )
- # The conductance from the parents to the branch point.
- branchpoint_conds_parents = vmap(
- compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)
- )(
- radiuses[self.par_inds, -1],
- axial_resistivity[self.par_inds, -1],
- lengths[self.par_inds, -1],
- )
-
- # Weights with which the compartments influence their nearby node.
- # The impact of the children on the branch point.
- branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(
- radiuses[self.child_inds, 0],
- axial_resistivity[self.child_inds, 0],
- lengths[self.child_inds, 0],
- )
- # The impact of parents on the branch point.
- branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(
- radiuses[self.par_inds, -1],
- axial_resistivity[self.par_inds, -1],
- lengths[self.par_inds, -1],
- )
-
- summed_coupling_conds = Cell.update_summed_coupling_conds_custom_spsolve(
- summed_coupling_conds,
- self.child_inds,
- self.par_inds,
- branchpoint_conds_children,
- branchpoint_conds_parents,
- )
-
- cond_params = {
- "branch_uppers": coupling_conds_bwd,
- "branch_lowers": coupling_conds_fwd,
- "branch_diags": summed_coupling_conds,
- "branchpoint_conds_children": branchpoint_conds_children,
- "branchpoint_conds_parents": branchpoint_conds_parents,
- "branchpoint_weights_children": branchpoint_weights_children,
- "branchpoint_weights_parents": branchpoint_weights_parents,
- }
- return cond_params
-
- def _init_conds_jax_spsolve(self, params: Dict):
- """Given length, radius, and r_a, set the coupling conductances."""
- conds = compute_axial_conductances(self._comp_edges, params)
- return {"axial_conductances": conds}
-
def init_syns(self):
"""Initialize synapses."""
self.synapses = []
diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py
index 7ab32a4f..80c1538a 100644
--- a/jaxley/solver_voltage.py
+++ b/jaxley/solver_voltage.py
@@ -4,12 +4,12 @@
from typing import List
import jax.numpy as jnp
+import numpy as np
from jax import vmap
from jax.experimental.sparse.linalg import spsolve as jax_spsolve
from tridiax.stone import stone_backsub_lower, stone_triang_upper
from tridiax.thomas import thomas_backsub_lower, thomas_triang_upper
-from jaxley.build_branched_tridiag import define_all_tridiags
from jaxley.utils.cell_utils import group_and_sum
@@ -17,13 +17,14 @@ def step_voltage_explicit(
voltages: jnp.ndarray,
voltage_terms: jnp.ndarray,
constant_terms: jnp.ndarray,
- coupling_conds_upper: jnp.ndarray,
- coupling_conds_lower: jnp.ndarray,
- summed_coupling_conds: jnp.ndarray,
- branchpoint_conds_children: jnp.ndarray,
- branchpoint_conds_parents: jnp.ndarray,
- branchpoint_weights_children: jnp.ndarray,
- branchpoint_weights_parents: jnp.ndarray,
+ axial_conductances: jnp.ndarray,
+ internal_node_inds: jnp.ndarray,
+ sinks: jnp.ndarray,
+ sources: jnp.ndarray,
+ types: jnp.ndarray,
+ masked_node_inds: jnp.ndarray,
+ nseg_per_branch: jnp.ndarray,
+ nseg: int,
par_inds: jnp.ndarray,
child_inds: jnp.ndarray,
nbranches: int,
@@ -40,17 +41,14 @@ def step_voltage_explicit(
voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))
constant_terms = jnp.reshape(constant_terms, (nbranches, -1))
- update = voltage_vectorfield(
+ update = _voltage_vectorfield(
voltages,
voltage_terms,
constant_terms,
- coupling_conds_upper,
- coupling_conds_lower,
- summed_coupling_conds,
- branchpoint_conds_children,
- branchpoint_conds_parents,
- branchpoint_weights_children,
- branchpoint_weights_parents,
+ types,
+ sources,
+ sinks,
+ axial_conductances,
par_inds,
child_inds,
nbranches,
@@ -66,17 +64,18 @@ def step_voltage_explicit(
return new_voltates.ravel(order="C")
-def step_voltage_implicit_with_custom_spsolve(
+def step_voltage_implicit_with_jaxley_spsolve(
voltages: jnp.ndarray,
voltage_terms: jnp.ndarray,
constant_terms: jnp.ndarray,
- coupling_conds_upper: jnp.ndarray,
- coupling_conds_lower: jnp.ndarray,
- summed_coupling_conds: jnp.ndarray,
- branchpoint_conds_children: jnp.ndarray,
- branchpoint_conds_parents: jnp.ndarray,
- branchpoint_weights_children: jnp.ndarray,
- branchpoint_weights_parents: jnp.ndarray,
+ axial_conductances: jnp.ndarray,
+ internal_node_inds: jnp.ndarray,
+ sinks: jnp.ndarray,
+ sources: jnp.ndarray,
+ types: jnp.ndarray,
+ masked_node_inds: jnp.ndarray,
+ nseg_per_branch: jnp.ndarray,
+ nseg: int,
par_inds: jnp.ndarray,
child_inds: jnp.ndarray,
nbranches: int,
@@ -89,24 +88,60 @@ def step_voltage_implicit_with_custom_spsolve(
debug_states,
):
"""Solve one timestep of branched nerve equations with implicit (backward) Euler."""
- voltages = jnp.reshape(voltages, (nbranches, -1))
- voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))
- constant_terms = jnp.reshape(constant_terms, (nbranches, -1))
- coupling_conds_upper = jnp.reshape(coupling_conds_upper, (nbranches, -1))
- coupling_conds_lower = jnp.reshape(coupling_conds_lower, (nbranches, -1))
- summed_coupling_conds = jnp.reshape(summed_coupling_conds, (nbranches, -1))
+ # Build diagonals.
+ c2c = np.isin(types, [0, 1, 2])
+ diags = jnp.ones(nbranches * nseg)
- # Define quasi-tridiagonal system.
- lowers, diags, uppers, solves = define_all_tridiags(
- voltages,
- voltage_terms,
- constant_terms,
- nbranches,
- coupling_conds_upper,
- coupling_conds_lower,
- summed_coupling_conds,
- delta_t,
+ # if-case needed because `.at` does not allow empty inputs, but the input is
+ # empty for compartments.
+ if len(sinks[c2c]) > 0:
+ diags = diags.at[masked_node_inds[sinks[c2c]]].add(
+ delta_t * axial_conductances[c2c]
+ )
+
+ diags = diags.at[masked_node_inds[internal_node_inds]].add(delta_t * voltage_terms)
+
+ # Build solves.
+ solves = jnp.zeros(nbranches * nseg)
+ solves = solves.at[masked_node_inds[internal_node_inds]].add(
+ voltages + delta_t * constant_terms
)
+
+ # Build upper and lower within the branch.
+ c2c = types == 0 # c2c = compartment-to-compartment.
+
+ # Build uppers.
+ uppers = jnp.zeros(nbranches * nseg)
+ upper_inds = sources[c2c] > sinks[c2c]
+ sinks_upper = sinks[c2c][upper_inds]
+ if len(sinks_upper) > 0:
+ uppers = uppers.at[masked_node_inds[sinks_upper]].add(
+ -delta_t * axial_conductances[c2c][upper_inds]
+ )
+
+ # Build lowers.
+ lowers = jnp.zeros(nbranches * nseg)
+ lower_inds = sources[c2c] < sinks[c2c]
+ sinks_lower = sinks[c2c][lower_inds]
+ if len(sinks_lower) > 0:
+ lowers = lowers.at[masked_node_inds[sinks_lower]].add(
+ -delta_t * axial_conductances[c2c][lower_inds]
+ )
+
+ # Reshape all diags, lowers, uppers, and solves into a "per-branch" format.
+ diags = jnp.reshape(diags, (nbranches, -1))
+ solves = jnp.reshape(solves, (nbranches, -1))
+ uppers = jnp.reshape(uppers, (nbranches, -1))
+ lowers = jnp.reshape(lowers, (nbranches, -1))
+ # lowers and uppers were built to have length `nseg` above for simplicity.
+ uppers = uppers[:, :-1]
+ lowers = lowers[:, 1:]
+
+ # Build branchpoint conductances.
+ branchpoint_conds_parents = axial_conductances[types == 1]
+ branchpoint_conds_children = axial_conductances[types == 2]
+ branchpoint_weights_parents = axial_conductances[types == 3]
+ branchpoint_weights_children = axial_conductances[types == 4]
all_branchpoint_vals = jnp.concatenate(
[branchpoint_weights_parents, branchpoint_weights_children]
)
@@ -169,6 +204,7 @@ def step_voltage_implicit_with_custom_spsolve(
children_in_level,
parents_in_level,
root_inds,
+ nseg_per_branch,
debug_states,
)
@@ -195,9 +231,10 @@ def step_voltage_implicit_with_custom_spsolve(
children_in_level,
parents_in_level,
root_inds,
+ nseg_per_branch,
debug_states,
)
- return solves.ravel(order="C")
+ return solves.ravel(order="C")[masked_node_inds[internal_node_inds]]
def step_voltage_implicit_with_jax_spsolve(
@@ -213,24 +250,22 @@ def step_voltage_implicit_with_jax_spsolve(
n_nodes,
internal_node_inds,
):
+ axial_conductances = delta_t * axial_conductances
+
# Build diagonals.
diagonal_values = jnp.zeros(n_nodes)
# if-case needed because `.at` does not allow empty inputs, but the input is
# empty for compartments.
if len(sinks) > 0:
- diagonal_values = diagonal_values.at[sinks].add(delta_t * axial_conductances)
+ diagonal_values = diagonal_values.at[sinks].add(axial_conductances)
diagonal_values = diagonal_values.at[internal_node_inds].add(
- delta_t * voltage_terms
+ 1.0 + delta_t * voltage_terms
)
- diagonal_values = diagonal_values.at[internal_node_inds].add(1.0)
- # Build off-diagonals.
- axial_conductances = -delta_t * axial_conductances
-
- # Concatenate diagonals and off-diagonals.
- all_values = jnp.concatenate([diagonal_values, axial_conductances])
+ # Concatenate diagonals and off-diagonals (which are just `-axial_conductances`).
+ all_values = jnp.concatenate([diagonal_values, -axial_conductances])
# Build solve.
solves = jnp.zeros(n_nodes)
@@ -243,45 +278,58 @@ def step_voltage_implicit_with_jax_spsolve(
return solution
-def voltage_vectorfield(
- voltages,
- voltage_terms,
- constant_terms,
- coupling_conds_upper,
- coupling_conds_lower,
- summed_coupling_conds,
- branchpoint_conds_children,
- branchpoint_conds_parents,
- branchpoint_weights_children,
- branchpoint_weights_parents,
- par_inds,
- child_inds,
- nbranches,
- solver,
- delta_t,
- children_in_level,
- parents_in_level,
- root_inds,
- branchpoint_group_inds,
+def _voltage_vectorfield(
+ voltages: jnp.ndarray,
+ voltage_terms: jnp.ndarray,
+ constant_terms: jnp.ndarray,
+ types: jnp.ndarray,
+ sources: jnp.ndarray,
+ sinks: jnp.ndarray,
+ axial_conductances: jnp.ndarray,
+ par_inds: jnp.ndarray,
+ child_inds: jnp.ndarray,
+ nbranches: int,
+ solver: str,
+ delta_t: float,
+ children_in_level: List[jnp.ndarray],
+ parents_in_level: List[jnp.ndarray],
+ root_inds: jnp.ndarray,
+ branchpoint_group_inds: jnp.ndarray,
debug_states,
) -> jnp.ndarray:
"""Evaluate the vectorfield of the nerve equation."""
+ if np.sum(np.isin(types, [1, 2, 3, 4])) > 0:
+ raise NotImplementedError(
+ f"Forward Euler is not implemented for branched morphologies."
+ )
+
# Membrane current update.
vecfield = -voltage_terms * voltages + constant_terms
- # Current through segments within the same branch.
- vecfield = vecfield.at[:, :-1].add(
- (voltages[:, 1:] - voltages[:, :-1]) * coupling_conds_upper
- )
- vecfield = vecfield.at[:, 1:].add(
- (voltages[:, :-1] - voltages[:, 1:]) * coupling_conds_lower
- )
+ # Build upper and lower within the branch.
+ c2c = types == 0 # c2c = compartment-to-compartment.
- # Current through branch points.
- if len(branchpoint_conds_children) > 0:
- raise NotImplementedError(
- f"Forward Euler is not implemented for branched morphologies."
- )
+ # Build uppers.
+ upper_inds = sources[c2c] > sinks[c2c]
+ if len(upper_inds) > 0:
+ uppers = axial_conductances[c2c][upper_inds]
+ else:
+ uppers = jnp.asarray([])
+
+ # Build lowers.
+ lower_inds = sources[c2c] < sinks[c2c]
+ if len(lower_inds) > 0:
+ lowers = axial_conductances[c2c][lower_inds]
+ else:
+ lowers = jnp.asarray([])
+
+ # For networks consisting of branches.
+ uppers = jnp.reshape(uppers, (nbranches, -1))
+ lowers = jnp.reshape(lowers, (nbranches, -1))
+
+ # Current through segments within the same branch.
+ vecfield = vecfield.at[:, :-1].add((voltages[:, 1:] - voltages[:, :-1]) * uppers)
+ vecfield = vecfield.at[:, 1:].add((voltages[:, :-1] - voltages[:, 1:]) * lowers)
return vecfield
@@ -301,6 +349,7 @@ def _triang_branched(
children_in_level,
parents_in_level,
root_inds,
+ nseg_per_branch,
debug_states,
):
"""Triangulation."""
@@ -329,6 +378,7 @@ def _triang_branched(
branchpoint_weights_parents,
branchpoint_diags,
branchpoint_solves,
+ nseg_per_branch,
)
# At last level, we do not want to eliminate anymore.
diags, lowers, solves, uppers = _triang_level(
@@ -361,6 +411,7 @@ def _backsub_branched(
children_in_level,
parents_in_level,
root_inds,
+ nseg_per_branch,
debug_states,
):
"""
@@ -378,6 +429,7 @@ def _backsub_branched(
solves,
branchpoint_weights_parents,
branchpoint_solves,
+ nseg_per_branch,
)
branchpoint_conds_children, solves = _eliminate_children_upper(
cil,
@@ -484,6 +536,7 @@ def _eliminate_parents_upper(
branchpoint_weights_parents,
branchpoint_diags,
branchpoint_solves,
+ nseg_per_branch: jnp.ndarray,
):
bil = pil[:, 0]
bpil = pil[:, 1]
@@ -495,8 +548,8 @@ def _eliminate_parents_upper(
)
# Update the diagonal elements and `b` in `Ax=b` (called `solves`).
- diags = diags.at[bil, -1].add(new_diag)
- solves = solves.at[bil, -1].add(new_solve)
+ diags = diags.at[bil, nseg_per_branch[bil] - 1].add(new_diag)
+ solves = solves.at[bil, nseg_per_branch[bil] - 1].add(new_solve)
branchpoint_conds_parents = branchpoint_conds_parents.at[bil].set(0.0)
return diags, solves, branchpoint_conds_parents
@@ -521,11 +574,14 @@ def _eliminate_parents_lower(
solves,
branchpoint_weights_parents,
branchpoint_solves,
+ nseg_per_branch: jnp.ndarray,
):
bil = pil[:, 0]
bpil = pil[:, 1]
branchpoint_solves = branchpoint_solves.at[bpil].add(
- -solves[bil, -1] * branchpoint_weights_parents[bil] / diags[bil, -1]
+ -solves[bil, nseg_per_branch[bil] - 1]
+ * branchpoint_weights_parents[bil]
+ / diags[bil, nseg_per_branch[bil] - 1]
)
branchpoint_weights_parents = branchpoint_weights_parents.at[bil].set(0.0)
return branchpoint_weights_parents, branchpoint_solves
diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py
index 136093b4..c35e3c4e 100644
--- a/jaxley/utils/cell_utils.py
+++ b/jaxley/utils/cell_utils.py
@@ -409,95 +409,29 @@ def query_channel_states_and_params(d, keys, idcs):
return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))
-def convert_to_csc(
- num_elements: int, row_ind: np.ndarray, col_ind: np.ndarray
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
- """Convert between two representations for sparse systems.
-
- This is needed because `jax.scipy.linalg.spsolve` requires the `(ind, indptr)`
- representation, but the `(row, col)` is more intuitive and easier to build.
-
- This function uses `np` instead of `jnp` because it only deals with indexing which
- can be dealt with only based on the branch structure (i.e. independent of any
- parameter values).
-
- Written by ChatGPT.
- """
- data_inds = np.arange(num_elements)
- # Step 1: Sort by (col_ind, row_ind)
- sorted_indices = np.lexsort((row_ind, col_ind))
- data_inds = data_inds[sorted_indices]
- row_ind = row_ind[sorted_indices]
- col_ind = col_ind[sorted_indices]
-
- # Step 2: Create indptr array
- n_cols = col_ind.max() + 1
- indptr = np.zeros(n_cols + 1, dtype=int)
- np.add.at(indptr, col_ind + 1, 1)
- np.cumsum(indptr, out=indptr)
-
- # Step 3: The row indices are already sorted
- indices = row_ind
-
- return data_inds, indices, indptr
-
-
-def comp_edges_to_indices(
- comp_edges: pd.DataFrame,
-) -> Tuple[int, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
- """Generates sparse matrix indices from the table of node edges.
-
- This is only used for the `jax.sparse` voltage solver.
-
- Args:
- comp_edges: Dataframe with three columns (sink, source, type).
-
- Returns:
- n_nodes: The number of total nodes (including branchpoints).
- data_inds: The indices to reorder the data.
- indices and indptr: Indices passed to the sparse matrix solver.
- """
- # Build indices for diagonals.
- sources = np.asarray(comp_edges["source"].to_list())
- sinks = np.asarray(comp_edges["sink"].to_list())
- n_nodes = np.max(sinks) + 1 if len(sinks) > 0 else 1
- diagonal_inds = jnp.stack([jnp.arange(n_nodes), jnp.arange(n_nodes)])
-
- # Build indices for off-diagonals.
- off_diagonal_inds = jnp.stack([sources, sinks]).astype(int)
-
- # Concatenate indices of diagonals and off-diagonals.
- all_inds = jnp.concatenate([diagonal_inds, off_diagonal_inds], axis=1)
-
- # Cast (row, col) indices to the format required for the `jax` sparse solver.
- data_inds, indices, indptr = convert_to_csc(
- num_elements=all_inds.shape[1],
- row_ind=all_inds[0],
- col_ind=all_inds[1],
- )
- return n_nodes, data_inds, indices, indptr
-
-
def compute_axial_conductances(
comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]
) -> jnp.ndarray:
"""Given `comp_edges`, radius, length, r_a, compute the axial conductances."""
- # `Compartment-to-compartment` (c2c) conductances.
+ # `Compartment-to-compartment` (c2c) axial coupling conductances.
condition = comp_edges["type"].to_numpy() == 0
source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())
sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())
- conds_c2c = vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(
- params["radius"][sink_comp_inds],
- params["radius"][source_comp_inds],
- params["axial_resistivity"][sink_comp_inds],
- params["axial_resistivity"][source_comp_inds],
- params["length"][sink_comp_inds],
- params["length"][source_comp_inds],
- )
+ if len(sink_comp_inds) > 0:
+ conds_c2c = vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(
+ params["radius"][sink_comp_inds],
+ params["radius"][source_comp_inds],
+ params["axial_resistivity"][sink_comp_inds],
+ params["axial_resistivity"][source_comp_inds],
+ params["length"][sink_comp_inds],
+ params["length"][source_comp_inds],
+ )
+ else:
+ conds_c2c = jnp.asarray([])
- # `branchpoint-to-compartment` (bp2c) conductances.
- condition = comp_edges["type"].to_numpy() == 1
+ # `branchpoint-to-compartment` (bp2c) axial coupling conductances.
+ condition = np.isin(comp_edges["type"].to_numpy(), [1, 2])
sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())
if len(sink_comp_inds) > 0:
@@ -509,8 +443,8 @@ def compute_axial_conductances(
else:
conds_bp2c = jnp.asarray([])
- # `compartment-to-branchpoint` (c2bp) conductances.
- condition = comp_edges["type"].to_numpy() == 2
+ # `compartment-to-branchpoint` (c2bp) axial coupling conductances.
+ condition = np.isin(comp_edges["type"].to_numpy(), [3, 4])
source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())
if len(source_comp_inds) > 0:
@@ -525,7 +459,7 @@ def compute_axial_conductances(
else:
conds_c2bp = jnp.asarray([])
- # All conductances.
+ # All axial coupling conductances.
return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp])
diff --git a/jaxley/utils/solver_utils.py b/jaxley/utils/solver_utils.py
new file mode 100644
index 00000000..36759e42
--- /dev/null
+++ b/jaxley/utils/solver_utils.py
@@ -0,0 +1,96 @@
+# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
+# licensed under the Apache License Version 2.0, see
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+
+
+def remap_index_to_masked(
+ index, nodes: pd.DataFrame, max_nseg: int, nseg_per_branch: jnp.ndarray
+):
+ """Convert actual index of the compartment to the index in the masked system.
+
+ E.g. if `nsegs = [2, 4]`, then the index `3` would be mapped to `5` because the
+ masked `nsegs` are `[4, 4]`.
+ """
+ cumsum_nseg_per_branch = jnp.concatenate(
+ [
+ jnp.asarray([0]),
+ jnp.cumsum(nseg_per_branch),
+ ]
+ )
+ branch_inds = nodes.loc[index, "branch_index"].to_numpy()
+ remainders = index - cumsum_nseg_per_branch[branch_inds]
+ return branch_inds * max_nseg + remainders
+
+
+def convert_to_csc(
+ num_elements: int, row_ind: np.ndarray, col_ind: np.ndarray
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Convert between two representations for sparse systems.
+
+ This is needed because `jax.scipy.linalg.spsolve` requires the `(ind, indptr)`
+ representation, but the `(row, col)` is more intuitive and easier to build.
+
+ This function uses `np` instead of `jnp` because it only deals with indexing which
+ can be dealt with only based on the branch structure (i.e. independent of any
+ parameter values).
+
+ Written by ChatGPT.
+ """
+ data_inds = np.arange(num_elements)
+ # Step 1: Sort by (col_ind, row_ind)
+ sorted_indices = np.lexsort((row_ind, col_ind))
+ data_inds = data_inds[sorted_indices]
+ row_ind = row_ind[sorted_indices]
+ col_ind = col_ind[sorted_indices]
+
+ # Step 2: Create indptr array
+ n_cols = col_ind.max() + 1
+ indptr = np.zeros(n_cols + 1, dtype=int)
+ np.add.at(indptr, col_ind + 1, 1)
+ np.cumsum(indptr, out=indptr)
+
+ # Step 3: The row indices are already sorted
+ indices = row_ind
+
+ return data_inds, indices, indptr
+
+
+def comp_edges_to_indices(
+ comp_edges: pd.DataFrame,
+) -> Tuple[int, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+ """Generates sparse matrix indices from the table of node edges.
+
+ This is only used for the `jax.sparse` voltage solver.
+
+ Args:
+ comp_edges: Dataframe with three columns (sink, source, type).
+
+ Returns:
+ n_nodes: The number of total nodes (including branchpoints).
+ data_inds: The indices to reorder the data.
+ indices and indptr: Indices passed to the sparse matrix solver.
+ """
+ # Build indices for diagonals.
+ sources = np.asarray(comp_edges["source"].to_list())
+ sinks = np.asarray(comp_edges["sink"].to_list())
+ n_nodes = np.max(sinks) + 1 if len(sinks) > 0 else 1
+ diagonal_inds = jnp.stack([jnp.arange(n_nodes), jnp.arange(n_nodes)])
+
+ # Build indices for off-diagonals.
+ off_diagonal_inds = jnp.stack([sources, sinks]).astype(int)
+
+ # Concatenate indices of diagonals and off-diagonals.
+ all_inds = jnp.concatenate([diagonal_inds, off_diagonal_inds], axis=1)
+
+ # Cast (row, col) indices to the format required for the `jax` sparse solver.
+ data_inds, indices, indptr = convert_to_csc(
+ num_elements=all_inds.shape[1],
+ row_ind=all_inds[0],
+ col_ind=all_inds[1],
+ )
+ return n_nodes, data_inds, indices, indptr
diff --git a/tests/jaxley_identical/test_basic_modules.py b/tests/jaxley_identical/test_basic_modules.py
index c78d9b23..ffa44072 100644
--- a/tests/jaxley_identical/test_basic_modules.py
+++ b/tests/jaxley_identical/test_basic_modules.py
@@ -27,13 +27,7 @@ def test_compartment(voltage_solver: str):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.02, dt, t_max)
-
- comp = jx.Compartment()
- comp.insert(HH())
- comp.record()
- comp.stimulate(current)
-
- voltages = jx.integrate(comp, delta_t=dt, voltage_solver=voltage_solver)
+ tolerance = 1e-8
voltages_081123 = jnp.asarray(
[
@@ -52,9 +46,43 @@ def test_compartment(voltage_solver: str):
]
]
)
+
+ # Test compartment.
+ comp = jx.Compartment()
+ comp.insert(HH())
+ comp.record()
+ comp.stimulate(current)
+ voltages = jx.integrate(comp, delta_t=dt, voltage_solver=voltage_solver)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
- tolerance = 1e-8
- assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
+ assert max_error <= tolerance, f"Compartment error is {max_error} > {tolerance}"
+
+ # Test branch of a single compartment.
+ branch = jx.Branch()
+ branch.insert(HH())
+ branch.record()
+ branch.stimulate(current)
+ voltages = jx.integrate(branch, delta_t=dt, voltage_solver=voltage_solver)
+ max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
+ assert max_error <= tolerance, f"Branch error is {max_error} > {tolerance}"
+
+ # Test cell of a single compartment.
+ cell = jx.Cell()
+ cell.insert(HH())
+ cell.record()
+ cell.stimulate(current)
+ voltages = jx.integrate(cell, delta_t=dt, voltage_solver=voltage_solver)
+ max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
+ assert max_error <= tolerance, f"Cell error is {max_error} > {tolerance}"
+
+ # Test net of a single compartment.
+ cell = jx.Cell()
+ net = jx.Network([cell])
+ net.insert(HH())
+ net.record()
+ net.stimulate(current)
+ voltages = jx.integrate(net, delta_t=dt, voltage_solver=voltage_solver)
+ max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
+ assert max_error <= tolerance, f"Network error is {max_error} > {tolerance}"
@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
@@ -94,6 +122,44 @@ def test_branch(voltage_solver: str):
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
+def test_branch_fwd_euler_uneven_radiuses():
+ dt = 0.025 # ms
+ t_max = 10.0 # ms
+ current = jx.step_current(0.5, 1.0, 2.0, dt, t_max)
+
+ comp = jx.Compartment()
+ branch = jx.Branch(comp, 8)
+ branch.set("axial_resistivity", 500.0)
+
+ rands1 = np.linspace(20, 300, 8)
+ rands2 = np.linspace(1, 5, 8)
+ branch.set("length", rands1)
+ branch.set("radius", rands2)
+
+ branch.insert(HH())
+ branch.loc(1.0).stimulate(current)
+ branch.loc(0.0).record()
+
+ voltages = jx.integrate(branch, delta_t=dt, solver="fwd_euler")
+
+ voltages_240920 = jnp.asarray(
+ [
+ -70.0,
+ -64.319374,
+ -61.61975,
+ -56.971237,
+ 25.785686,
+ -42.466354,
+ -75.86178,
+ -75.06558,
+ -73.95041,
+ ]
+ )
+ tolerance = 1e-5
+ max_error = jnp.max(jnp.abs(voltages_240920 - voltages[0, ::50]))
+ assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
+
+
@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
def test_cell(voltage_solver: str):
nseg_per_branch = 2
diff --git a/tests/jaxley_identical/test_swc.py b/tests/jaxley_identical/test_swc.py
index 959ba79c..24e49a44 100644
--- a/tests/jaxley_identical/test_swc.py
+++ b/tests/jaxley_identical/test_swc.py
@@ -21,8 +21,9 @@
from jaxley.synapses import IonotropicSynapse
+@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
@pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"])
-def test_swc_cell(file):
+def test_swc_cell(voltage_solver: str, file: str):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.2, dt, t_max)
@@ -35,7 +36,7 @@ def test_swc_cell(file):
cell.branch(1).loc(0.0).record()
cell.branch(1).loc(0.0).stimulate(current)
- voltages = jx.integrate(cell, delta_t=dt)
+ voltages = jx.integrate(cell, delta_t=dt, voltage_solver=voltage_solver)
if file == "morph_single_point_soma.swc":
voltages_300724 = jnp.asarray(
@@ -80,7 +81,8 @@ def test_swc_cell(file):
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
-def test_swc_net():
+@pytest.mark.parametrize("voltage_solver", ["jaxley.stone", "jax.sparse"])
+def test_swc_net(voltage_solver: str):
dt = 0.025 # ms
t_max = 5.0 # ms
current = jx.step_current(0.5, 1.0, 0.2, dt, t_max)
@@ -111,7 +113,7 @@ def test_swc_net():
for stim_ind in range(2):
network.cell(stim_ind).branch(1).loc(0.0).stimulate(current)
- voltages = jx.integrate(network, delta_t=dt)
+ voltages = jx.integrate(network, delta_t=dt, voltage_solver=voltage_solver)
voltages_300724 = jnp.asarray(
[