Skip to content

Commit

Permalink
Use the new types to get rid of init_conds_custom
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Sep 20, 2024
1 parent 2c71bf0 commit 9e782e3
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 376 deletions.
131 changes: 50 additions & 81 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import networkx as nx
import numpy as np
import pandas as pd
from jax import jit, vmap
Expand All @@ -17,13 +16,14 @@
from jaxley.channels import Channel
from jaxley.solver_voltage import (
step_voltage_explicit,
step_voltage_implicit_with_custom_spsolve,
step_voltage_implicit_with_jax_spsolve,
step_voltage_implicit_with_jaxley_spsolve,
)
from jaxley.synapses import Synapse
from jaxley.utils.cell_utils import (
_compute_index_of_child,
_compute_num_children,
compute_axial_conductances,
compute_levels,
convert_point_process_to_distributed,
convert_to_csc,
Expand Down Expand Up @@ -258,7 +258,7 @@ def _show(

def init_morph(self):
"""Initialize the morphology such that it can be processed by the solvers."""
self._init_morph_custom_spsolve()
self._init_morph_jaxley_spsolve()
self._init_morph_jax_spsolve()
self.initialized_morph = True

Expand All @@ -268,40 +268,13 @@ def _init_morph_jax_spsolve(self):
raise NotImplementedError

@abstractmethod
def _init_morph_custom_spsolve(self):
def _init_morph_jaxley_spsolve(self):
"""Initialize the morphology for the custom Jaxley solver."""
raise NotImplementedError

def init_conds(self, params: Dict, voltage_solver: str):
def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]):
"""Given radius, length, r_a, compute the axial coupling conductances."""
if voltage_solver.startswith("jaxley"):
all_conds = self._init_conds_custom_spsolve(params)
conds = self._init_conds_jax_spsolve(params)
for key in conds.keys():
all_conds[key] = conds[key]
return all_conds
else:
return self._init_conds_jax_spsolve(params)

@abstractmethod
def _init_conds_jax_spsolve(self, params: Dict):
"""Initialize coupling conductances.
Args:
params: Conductances and morphology parameters, not yet including
coupling conductances.
"""
raise NotImplementedError

@abstractmethod
def _init_conds_custom_spsolve(self, params: Dict):
"""Initialize coupling conductances.
Args:
params: Conductances and morphology parameters, not yet including
coupling conductances.
"""
raise NotImplementedError
return compute_axial_conductances(self._comp_edges, params)

def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"):
"""Adds channel nodes from constituents to `self.channel_nodes`."""
Expand Down Expand Up @@ -529,9 +502,9 @@ def get_all_parameters(
) -> Dict[str, jnp.ndarray]:
"""Return all parameters (and coupling conductances) needed to simulate.
Runs `init_conds()` and return every parameter that is needed to solve the ODE.
This includes conductances, radiuses, lengths, axial_resistivities, but also
coupling conductances.
Runs `_compute_axial_conductances()` and return every parameter that is needed
to solve the ODE. This includes conductances, radiuses, lengths,
axial_resistivities, but also coupling conductances.
This is done by first obtaining the current value of every parameter (not only
the trainable ones) and then replacing the trainable ones with the value
Expand Down Expand Up @@ -580,11 +553,8 @@ def get_all_parameters(
# `.set()` to work. This is done with `[:, None]`.
params[key] = params[key].at[inds].set(set_param[:, None])

# Compute conductance params and append them.
cond_params = self.init_conds(params=params, voltage_solver=voltage_solver)
for key in cond_params:
params[key] = cond_params[key]

# Compute conductance params and add them to the params dictionary.
params["axial_conductances"] = self._compute_axial_conductances(params=params)
return params

def get_states_from_nodes_and_edges(self):
Expand Down Expand Up @@ -967,19 +937,26 @@ def step(
# Voltage steps.
cm = params["capacitance"] # Abbreviation.

# Arguments used by all solvers.
solver_kwargs = {
"voltages": voltages,
"voltage_terms": (v_terms + syn_v_terms) / cm,
"constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
"axial_conductances": params["axial_conductances"],
"internal_node_inds": self._internal_node_inds,
}

# Add solver specific arguments.
if voltage_solver == "jax.sparse":
solver_kwargs = {
"voltages": voltages,
"voltage_terms": (v_terms + syn_v_terms) / cm,
"constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
"axial_conductances": params["axial_conductances"],
"data_inds": self._data_inds,
"indices": self._indices_jax_spsolve,
"indptr": self._indptr_jax_spsolve,
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"n_nodes": self._n_nodes,
"internal_node_inds": self._internal_node_inds,
}
solver_kwargs.update(
{
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"data_inds": self._data_inds,
"indices": self._indices_jax_spsolve,
"indptr": self._indptr_jax_spsolve,
"n_nodes": self._n_nodes,
}
)
# Only for `bwd_euler` and `cranck-nicolson`.
step_voltage_implicit = step_voltage_implicit_with_jax_spsolve
else:
Expand All @@ -989,35 +966,27 @@ def step(
# Currently, the forward Euler solver also uses this format. However,
# this is only for historical reasons and we are planning to change this in
# the future.
solver_kwargs = {
"voltages": voltages,
"voltage_terms": (v_terms + syn_v_terms) / cm,
"constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
"branchpoint_conds_children": params["branchpoint_conds_children"],
"branchpoint_conds_parents": params["branchpoint_conds_parents"],
"branchpoint_weights_children": params["branchpoint_weights_children"],
"branchpoint_weights_parents": params["branchpoint_weights_parents"],
"axial_conductances": params["axial_conductances"],
"sources": np.asarray(self._comp_edges["source"].to_list()),
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"types": np.asarray(self._comp_edges["type"].to_list()),
"internal_node_inds": self._internal_node_inds,
"masked_node_inds": self._remapped_node_indices,
"n_nodes": self._n_nodes,
"nseg_per_branch": self.nseg_per_branch,
"nseg": self.nseg,
"par_inds": self.par_inds,
"child_inds": self.child_inds,
"nbranches": self.total_nbranches,
"solver": voltage_solver,
"children_in_level": self.children_in_level,
"parents_in_level": self.parents_in_level,
"root_inds": self.root_inds,
"branchpoint_group_inds": self.branchpoint_group_inds,
"debug_states": self.debug_states,
}
solver_kwargs.update(
{
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"sources": np.asarray(self._comp_edges["source"].to_list()),
"types": np.asarray(self._comp_edges["type"].to_list()),
"masked_node_inds": self._remapped_node_indices,
"nseg_per_branch": self.nseg_per_branch,
"nseg": self.nseg,
"par_inds": self.par_inds,
"child_inds": self.child_inds,
"nbranches": self.total_nbranches,
"solver": voltage_solver,
"children_in_level": self.children_in_level,
"parents_in_level": self.parents_in_level,
"root_inds": self.root_inds,
"branchpoint_group_inds": self.branchpoint_group_inds,
"debug_states": self.debug_states,
}
)
# Only for `bwd_euler` and `cranck-nicolson`.
step_voltage_implicit = step_voltage_implicit_with_custom_spsolve
step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve

if solver == "bwd_euler":
u["v"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)
Expand Down
69 changes: 2 additions & 67 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,10 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import vmap

from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.compartment import Compartment, CompartmentView
from jaxley.utils.cell_utils import (
comp_edges_to_indices,
compute_axial_conductances,
compute_children_and_parents,
compute_coupling_cond,
)
from jaxley.utils.cell_utils import comp_edges_to_indices, compute_children_and_parents


class Branch(Module):
Expand Down Expand Up @@ -118,7 +112,7 @@ def __getattr__(self, key: str):
else:
raise KeyError(f"Key {key} not recognized.")

def _init_morph_custom_spsolve(self):
def _init_morph_jaxley_spsolve(self):
self.branchpoint_group_inds = np.asarray([]).astype(int)
self.root_inds = jnp.asarray([0])
self._remapped_node_indices = self._internal_node_inds
Expand Down Expand Up @@ -148,65 +142,6 @@ def _init_morph_jax_spsolve(self):
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr

def _init_conds_custom_spsolve(self, params: Dict) -> Dict[str, jnp.ndarray]:
conds = self.init_branch_conds_custom_spsolve(
params["axial_resistivity"], params["radius"], params["length"], self.nseg
)
cond_params = {
"branchpoint_conds_children": jnp.asarray([]),
"branchpoint_conds_parents": jnp.asarray([]),
"branchpoint_weights_children": jnp.asarray([]),
"branchpoint_weights_parents": jnp.asarray([]),
}
cond_params["branch_lowers"] = conds[0]
cond_params["branch_uppers"] = conds[1]
cond_params["branch_diags"] = conds[2]

return cond_params

def _init_conds_jax_spsolve(self, params: Dict) -> Dict[str, jnp.ndarray]:
conds = compute_axial_conductances(self._comp_edges, params)
return {"axial_conductances": conds}

@staticmethod
def init_branch_conds_custom_spsolve(
axial_resistivity: jnp.ndarray,
radiuses: jnp.ndarray,
lengths: jnp.ndarray,
nseg: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Given an axial resisitivity, set the coupling conductances.
Args:
axial_resistivity: Axial resistivity of each compartment.
radiuses: Radius of each compartment.
lengths: Length of each compartment.
nseg: Number of compartments in the branch.
Returns:
Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.
"""

# Compute coupling conductance for segments within a branch.
# `radius`: um
# `r_a`: ohm cm
# `length_single_compartment`: um
# `coupling_conds`: S * um / cm / um^2 = S / cm / um
r1 = radiuses[:-1]
r2 = radiuses[1:]
r_a1 = axial_resistivity[:-1]
r_a2 = axial_resistivity[1:]
l1 = lengths[:-1]
l2 = lengths[1:]
coupling_conds_bwd = compute_coupling_cond(r1, r2, r_a1, r_a2, l1, l2)
coupling_conds_fwd = compute_coupling_cond(r2, r1, r_a2, r_a1, l2, l1)

# Compute the summed coupling conductances of each compartment.
summed_coupling_conds = jnp.zeros((nseg))
summed_coupling_conds = summed_coupling_conds.at[1:].add(coupling_conds_fwd)
summed_coupling_conds = summed_coupling_conds.at[:-1].add(coupling_conds_bwd)
return coupling_conds_fwd, coupling_conds_bwd, summed_coupling_conds

def __len__(self) -> int:
return self.nseg

Expand Down
Loading

0 comments on commit 9e782e3

Please sign in to comment.