Skip to content

Commit

Permalink
Move some utilities to solver_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Sep 20, 2024
1 parent 9e782e3 commit 54b45a5
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 99 deletions.
2 changes: 1 addition & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
compute_axial_conductances,
compute_levels,
convert_point_process_to_distributed,
convert_to_csc,
interpolate_xyz,
loc_of_index,
query_channel_states_and_params,
Expand All @@ -35,6 +34,7 @@
from jaxley.utils.debug_solver import compute_morphology_indices
from jaxley.utils.misc_utils import childview, concat_and_ignore_empty
from jaxley.utils.plot_utils import plot_morph
from jaxley.utils.solver_utils import convert_to_csc


class Module(ABC):
Expand Down
3 changes: 2 additions & 1 deletion jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.compartment import Compartment, CompartmentView
from jaxley.utils.cell_utils import comp_edges_to_indices, compute_children_and_parents
from jaxley.utils.cell_utils import compute_children_and_parents
from jaxley.utils.solver_utils import comp_edges_to_indices


class Branch(Module):
Expand Down
3 changes: 1 addition & 2 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
from jaxley.synapses import Synapse
from jaxley.utils.cell_utils import (
build_branchpoint_group_inds,
comp_edges_to_indices,
compute_children_and_parents,
compute_children_in_level,
compute_children_indices,
compute_levels,
compute_morphology_indices_in_levels,
compute_parents_in_level,
remap_index_to_masked,
)
from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked
from jaxley.utils.swc import swc_to_jaxley


Expand Down
2 changes: 1 addition & 1 deletion jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from jaxley.modules.base import Module, View
from jaxley.utils.cell_utils import (
comp_edges_to_indices,
compute_children_and_parents,
index_of_loc,
interpolate_xyz,
loc_of_index,
)
from jaxley.utils.solver_utils import comp_edges_to_indices


class Compartment(Module):
Expand Down
3 changes: 1 addition & 2 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
from jaxley.modules.cell import Cell, CellView
from jaxley.utils.cell_utils import (
build_branchpoint_group_inds,
comp_edges_to_indices,
compute_children_and_parents,
convert_point_process_to_distributed,
merge_cells,
remap_index_to_masked,
)
from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked
from jaxley.utils.syn_utils import gather_synapes


Expand Down
96 changes: 4 additions & 92 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,80 +409,11 @@ def query_channel_states_and_params(d, keys, idcs):
return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))


def convert_to_csc(
num_elements: int, row_ind: np.ndarray, col_ind: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Convert between two representations for sparse systems.
This is needed because `jax.scipy.linalg.spsolve` requires the `(ind, indptr)`
representation, but the `(row, col)` is more intuitive and easier to build.
This function uses `np` instead of `jnp` because it only deals with indexing which
can be dealt with only based on the branch structure (i.e. independent of any
parameter values).
Written by ChatGPT.
"""
data_inds = np.arange(num_elements)
# Step 1: Sort by (col_ind, row_ind)
sorted_indices = np.lexsort((row_ind, col_ind))
data_inds = data_inds[sorted_indices]
row_ind = row_ind[sorted_indices]
col_ind = col_ind[sorted_indices]

# Step 2: Create indptr array
n_cols = col_ind.max() + 1
indptr = np.zeros(n_cols + 1, dtype=int)
np.add.at(indptr, col_ind + 1, 1)
np.cumsum(indptr, out=indptr)

# Step 3: The row indices are already sorted
indices = row_ind

return data_inds, indices, indptr


def comp_edges_to_indices(
comp_edges: pd.DataFrame,
) -> Tuple[int, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Generates sparse matrix indices from the table of node edges.
This is only used for the `jax.sparse` voltage solver.
Args:
comp_edges: Dataframe with three columns (sink, source, type).
Returns:
n_nodes: The number of total nodes (including branchpoints).
data_inds: The indices to reorder the data.
indices and indptr: Indices passed to the sparse matrix solver.
"""
# Build indices for diagonals.
sources = np.asarray(comp_edges["source"].to_list())
sinks = np.asarray(comp_edges["sink"].to_list())
n_nodes = np.max(sinks) + 1 if len(sinks) > 0 else 1
diagonal_inds = jnp.stack([jnp.arange(n_nodes), jnp.arange(n_nodes)])

# Build indices for off-diagonals.
off_diagonal_inds = jnp.stack([sources, sinks]).astype(int)

# Concatenate indices of diagonals and off-diagonals.
all_inds = jnp.concatenate([diagonal_inds, off_diagonal_inds], axis=1)

# Cast (row, col) indices to the format required for the `jax` sparse solver.
data_inds, indices, indptr = convert_to_csc(
num_elements=all_inds.shape[1],
row_ind=all_inds[0],
col_ind=all_inds[1],
)
return n_nodes, data_inds, indices, indptr


def compute_axial_conductances(
comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]
) -> jnp.ndarray:
"""Given `comp_edges`, radius, length, r_a, compute the axial conductances."""
# `Compartment-to-compartment` (c2c) conductances.
# `Compartment-to-compartment` (c2c) axial coupling conductances.
condition = comp_edges["type"].to_numpy() == 0
source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())
sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())
Expand All @@ -499,7 +430,7 @@ def compute_axial_conductances(
else:
conds_c2c = jnp.asarray([])

# `branchpoint-to-compartment` (bp2c) conductances.
# `branchpoint-to-compartment` (bp2c) axial coupling conductances.
condition = np.isin(comp_edges["type"].to_numpy(), [1, 2])
sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())

Expand All @@ -512,7 +443,7 @@ def compute_axial_conductances(
else:
conds_bp2c = jnp.asarray([])

# `compartment-to-branchpoint` (c2bp) conductances.
# `compartment-to-branchpoint` (c2bp) axial coupling conductances.
condition = np.isin(comp_edges["type"].to_numpy(), [3, 4])
source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())

Expand All @@ -528,7 +459,7 @@ def compute_axial_conductances(
else:
conds_c2bp = jnp.asarray([])

# All conductances.
# All axial coupling conductances.
return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp])


Expand All @@ -541,22 +472,3 @@ def compute_children_and_parents(
child_belongs_to_branchpoint = remap_to_consecutive(par_inds)
par_inds = np.unique(par_inds)
return par_inds, child_inds, child_belongs_to_branchpoint


def remap_index_to_masked(
index, nodes: pd.DataFrame, max_nseg: int, nseg_per_branch: jnp.ndarray
):
"""Convert actual index of the compartment to the index in the masked system.
E.g. if `nsegs = [2, 4]`, then the index `3` would be mapped to `5` because the
masked `nsegs` are `[4, 4]`.
"""
cumsum_nseg_per_branch = jnp.concatenate(
[
jnp.asarray([0]),
jnp.cumsum(nseg_per_branch),
]
)
branch_inds = nodes.loc[index, "branch_index"].to_numpy()
remainders = index - cumsum_nseg_per_branch[branch_inds]
return branch_inds * max_nseg + remainders
96 changes: 96 additions & 0 deletions jaxley/utils/solver_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd


def remap_index_to_masked(
index, nodes: pd.DataFrame, max_nseg: int, nseg_per_branch: jnp.ndarray
):
"""Convert actual index of the compartment to the index in the masked system.
E.g. if `nsegs = [2, 4]`, then the index `3` would be mapped to `5` because the
masked `nsegs` are `[4, 4]`.
"""
cumsum_nseg_per_branch = jnp.concatenate(
[
jnp.asarray([0]),
jnp.cumsum(nseg_per_branch),
]
)
branch_inds = nodes.loc[index, "branch_index"].to_numpy()
remainders = index - cumsum_nseg_per_branch[branch_inds]
return branch_inds * max_nseg + remainders


def convert_to_csc(
num_elements: int, row_ind: np.ndarray, col_ind: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Convert between two representations for sparse systems.
This is needed because `jax.scipy.linalg.spsolve` requires the `(ind, indptr)`
representation, but the `(row, col)` is more intuitive and easier to build.
This function uses `np` instead of `jnp` because it only deals with indexing which
can be dealt with only based on the branch structure (i.e. independent of any
parameter values).
Written by ChatGPT.
"""
data_inds = np.arange(num_elements)
# Step 1: Sort by (col_ind, row_ind)
sorted_indices = np.lexsort((row_ind, col_ind))
data_inds = data_inds[sorted_indices]
row_ind = row_ind[sorted_indices]
col_ind = col_ind[sorted_indices]

# Step 2: Create indptr array
n_cols = col_ind.max() + 1
indptr = np.zeros(n_cols + 1, dtype=int)
np.add.at(indptr, col_ind + 1, 1)
np.cumsum(indptr, out=indptr)

# Step 3: The row indices are already sorted
indices = row_ind

return data_inds, indices, indptr


def comp_edges_to_indices(
comp_edges: pd.DataFrame,
) -> Tuple[int, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Generates sparse matrix indices from the table of node edges.
This is only used for the `jax.sparse` voltage solver.
Args:
comp_edges: Dataframe with three columns (sink, source, type).
Returns:
n_nodes: The number of total nodes (including branchpoints).
data_inds: The indices to reorder the data.
indices and indptr: Indices passed to the sparse matrix solver.
"""
# Build indices for diagonals.
sources = np.asarray(comp_edges["source"].to_list())
sinks = np.asarray(comp_edges["sink"].to_list())
n_nodes = np.max(sinks) + 1 if len(sinks) > 0 else 1
diagonal_inds = jnp.stack([jnp.arange(n_nodes), jnp.arange(n_nodes)])

# Build indices for off-diagonals.
off_diagonal_inds = jnp.stack([sources, sinks]).astype(int)

# Concatenate indices of diagonals and off-diagonals.
all_inds = jnp.concatenate([diagonal_inds, off_diagonal_inds], axis=1)

# Cast (row, col) indices to the format required for the `jax` sparse solver.
data_inds, indices, indptr = convert_to_csc(
num_elements=all_inds.shape[1],
row_ind=all_inds[0],
col_ind=all_inds[1],
)
return n_nodes, data_inds, indices, indptr

0 comments on commit 54b45a5

Please sign in to comment.