diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index fe461782..42da6653 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -26,7 +26,6 @@ compute_axial_conductances, compute_levels, convert_point_process_to_distributed, - convert_to_csc, interpolate_xyz, loc_of_index, query_channel_states_and_params, @@ -35,6 +34,7 @@ from jaxley.utils.debug_solver import compute_morphology_indices from jaxley.utils.misc_utils import childview, concat_and_ignore_empty from jaxley.utils.plot_utils import plot_morph +from jaxley.utils.solver_utils import convert_to_csc class Module(ABC): diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index 0eb23bca..3933c929 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -10,7 +10,8 @@ from jaxley.modules.base import GroupView, Module, View from jaxley.modules.compartment import Compartment, CompartmentView -from jaxley.utils.cell_utils import comp_edges_to_indices, compute_children_and_parents +from jaxley.utils.cell_utils import compute_children_and_parents +from jaxley.utils.solver_utils import comp_edges_to_indices class Branch(Module): diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 391bdef6..d8aa513b 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -13,15 +13,14 @@ from jaxley.synapses import Synapse from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, - comp_edges_to_indices, compute_children_and_parents, compute_children_in_level, compute_children_indices, compute_levels, compute_morphology_indices_in_levels, compute_parents_in_level, - remap_index_to_masked, ) +from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked from jaxley.utils.swc import swc_to_jaxley diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index c687888f..9a3571c3 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -10,12 +10,12 @@ from jaxley.modules.base import Module, View from jaxley.utils.cell_utils import ( - comp_edges_to_indices, compute_children_and_parents, index_of_loc, interpolate_xyz, loc_of_index, ) +from jaxley.utils.solver_utils import comp_edges_to_indices class Compartment(Module): diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 47b36489..707777ef 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -17,12 +17,11 @@ from jaxley.modules.cell import Cell, CellView from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, - comp_edges_to_indices, compute_children_and_parents, convert_point_process_to_distributed, merge_cells, - remap_index_to_masked, ) +from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked from jaxley.utils.syn_utils import gather_synapes diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index 5bfff2b3..c35e3c4e 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -409,80 +409,11 @@ def query_channel_states_and_params(d, keys, idcs): return dict(zip(keys, (v[idcs] for v in map(d.get, keys)))) -def convert_to_csc( - num_elements: int, row_ind: np.ndarray, col_ind: np.ndarray -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Convert between two representations for sparse systems. - - This is needed because `jax.scipy.linalg.spsolve` requires the `(ind, indptr)` - representation, but the `(row, col)` is more intuitive and easier to build. - - This function uses `np` instead of `jnp` because it only deals with indexing which - can be dealt with only based on the branch structure (i.e. independent of any - parameter values). - - Written by ChatGPT. - """ - data_inds = np.arange(num_elements) - # Step 1: Sort by (col_ind, row_ind) - sorted_indices = np.lexsort((row_ind, col_ind)) - data_inds = data_inds[sorted_indices] - row_ind = row_ind[sorted_indices] - col_ind = col_ind[sorted_indices] - - # Step 2: Create indptr array - n_cols = col_ind.max() + 1 - indptr = np.zeros(n_cols + 1, dtype=int) - np.add.at(indptr, col_ind + 1, 1) - np.cumsum(indptr, out=indptr) - - # Step 3: The row indices are already sorted - indices = row_ind - - return data_inds, indices, indptr - - -def comp_edges_to_indices( - comp_edges: pd.DataFrame, -) -> Tuple[int, jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """Generates sparse matrix indices from the table of node edges. - - This is only used for the `jax.sparse` voltage solver. - - Args: - comp_edges: Dataframe with three columns (sink, source, type). - - Returns: - n_nodes: The number of total nodes (including branchpoints). - data_inds: The indices to reorder the data. - indices and indptr: Indices passed to the sparse matrix solver. - """ - # Build indices for diagonals. - sources = np.asarray(comp_edges["source"].to_list()) - sinks = np.asarray(comp_edges["sink"].to_list()) - n_nodes = np.max(sinks) + 1 if len(sinks) > 0 else 1 - diagonal_inds = jnp.stack([jnp.arange(n_nodes), jnp.arange(n_nodes)]) - - # Build indices for off-diagonals. - off_diagonal_inds = jnp.stack([sources, sinks]).astype(int) - - # Concatenate indices of diagonals and off-diagonals. - all_inds = jnp.concatenate([diagonal_inds, off_diagonal_inds], axis=1) - - # Cast (row, col) indices to the format required for the `jax` sparse solver. - data_inds, indices, indptr = convert_to_csc( - num_elements=all_inds.shape[1], - row_ind=all_inds[0], - col_ind=all_inds[1], - ) - return n_nodes, data_inds, indices, indptr - - def compute_axial_conductances( comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray] ) -> jnp.ndarray: """Given `comp_edges`, radius, length, r_a, compute the axial conductances.""" - # `Compartment-to-compartment` (c2c) conductances. + # `Compartment-to-compartment` (c2c) axial coupling conductances. condition = comp_edges["type"].to_numpy() == 0 source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list()) sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list()) @@ -499,7 +430,7 @@ def compute_axial_conductances( else: conds_c2c = jnp.asarray([]) - # `branchpoint-to-compartment` (bp2c) conductances. + # `branchpoint-to-compartment` (bp2c) axial coupling conductances. condition = np.isin(comp_edges["type"].to_numpy(), [1, 2]) sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list()) @@ -512,7 +443,7 @@ def compute_axial_conductances( else: conds_bp2c = jnp.asarray([]) - # `compartment-to-branchpoint` (c2bp) conductances. + # `compartment-to-branchpoint` (c2bp) axial coupling conductances. condition = np.isin(comp_edges["type"].to_numpy(), [3, 4]) source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list()) @@ -528,7 +459,7 @@ def compute_axial_conductances( else: conds_c2bp = jnp.asarray([]) - # All conductances. + # All axial coupling conductances. return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp]) @@ -541,22 +472,3 @@ def compute_children_and_parents( child_belongs_to_branchpoint = remap_to_consecutive(par_inds) par_inds = np.unique(par_inds) return par_inds, child_inds, child_belongs_to_branchpoint - - -def remap_index_to_masked( - index, nodes: pd.DataFrame, max_nseg: int, nseg_per_branch: jnp.ndarray -): - """Convert actual index of the compartment to the index in the masked system. - - E.g. if `nsegs = [2, 4]`, then the index `3` would be mapped to `5` because the - masked `nsegs` are `[4, 4]`. - """ - cumsum_nseg_per_branch = jnp.concatenate( - [ - jnp.asarray([0]), - jnp.cumsum(nseg_per_branch), - ] - ) - branch_inds = nodes.loc[index, "branch_index"].to_numpy() - remainders = index - cumsum_nseg_per_branch[branch_inds] - return branch_inds * max_nseg + remainders diff --git a/jaxley/utils/solver_utils.py b/jaxley/utils/solver_utils.py new file mode 100644 index 00000000..36759e42 --- /dev/null +++ b/jaxley/utils/solver_utils.py @@ -0,0 +1,96 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + +from typing import Dict, List, Optional, Tuple, Union + +import jax.numpy as jnp +import numpy as np +import pandas as pd + + +def remap_index_to_masked( + index, nodes: pd.DataFrame, max_nseg: int, nseg_per_branch: jnp.ndarray +): + """Convert actual index of the compartment to the index in the masked system. + + E.g. if `nsegs = [2, 4]`, then the index `3` would be mapped to `5` because the + masked `nsegs` are `[4, 4]`. + """ + cumsum_nseg_per_branch = jnp.concatenate( + [ + jnp.asarray([0]), + jnp.cumsum(nseg_per_branch), + ] + ) + branch_inds = nodes.loc[index, "branch_index"].to_numpy() + remainders = index - cumsum_nseg_per_branch[branch_inds] + return branch_inds * max_nseg + remainders + + +def convert_to_csc( + num_elements: int, row_ind: np.ndarray, col_ind: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Convert between two representations for sparse systems. + + This is needed because `jax.scipy.linalg.spsolve` requires the `(ind, indptr)` + representation, but the `(row, col)` is more intuitive and easier to build. + + This function uses `np` instead of `jnp` because it only deals with indexing which + can be dealt with only based on the branch structure (i.e. independent of any + parameter values). + + Written by ChatGPT. + """ + data_inds = np.arange(num_elements) + # Step 1: Sort by (col_ind, row_ind) + sorted_indices = np.lexsort((row_ind, col_ind)) + data_inds = data_inds[sorted_indices] + row_ind = row_ind[sorted_indices] + col_ind = col_ind[sorted_indices] + + # Step 2: Create indptr array + n_cols = col_ind.max() + 1 + indptr = np.zeros(n_cols + 1, dtype=int) + np.add.at(indptr, col_ind + 1, 1) + np.cumsum(indptr, out=indptr) + + # Step 3: The row indices are already sorted + indices = row_ind + + return data_inds, indices, indptr + + +def comp_edges_to_indices( + comp_edges: pd.DataFrame, +) -> Tuple[int, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Generates sparse matrix indices from the table of node edges. + + This is only used for the `jax.sparse` voltage solver. + + Args: + comp_edges: Dataframe with three columns (sink, source, type). + + Returns: + n_nodes: The number of total nodes (including branchpoints). + data_inds: The indices to reorder the data. + indices and indptr: Indices passed to the sparse matrix solver. + """ + # Build indices for diagonals. + sources = np.asarray(comp_edges["source"].to_list()) + sinks = np.asarray(comp_edges["sink"].to_list()) + n_nodes = np.max(sinks) + 1 if len(sinks) > 0 else 1 + diagonal_inds = jnp.stack([jnp.arange(n_nodes), jnp.arange(n_nodes)]) + + # Build indices for off-diagonals. + off_diagonal_inds = jnp.stack([sources, sinks]).astype(int) + + # Concatenate indices of diagonals and off-diagonals. + all_inds = jnp.concatenate([diagonal_inds, off_diagonal_inds], axis=1) + + # Cast (row, col) indices to the format required for the `jax` sparse solver. + data_inds, indices, indptr = convert_to_csc( + num_elements=all_inds.shape[1], + row_ind=all_inds[0], + col_ind=all_inds[1], + ) + return n_nodes, data_inds, indices, indptr