Skip to content

Commit

Permalink
Bugfix for capacitances
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Sep 23, 2024
1 parent fe9a440 commit 7f39df7
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 19 deletions.
7 changes: 6 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,12 @@ def _step_channels_state(
indices = channel_nodes["comp_index"].to_numpy()
for channel in channels:
channel_param_names = list(channel.channel_params)
channel_param_names += ["radius", "length", "axial_resistivity"]
channel_param_names += [
"radius",
"length",
"axial_resistivity",
"capacitance",
]
channel_state_names = list(channel.channel_states)
channel_state_names += self.membrane_current_names
channel_indices = indices[channel_nodes[channel._name].astype(bool)]
Expand Down
34 changes: 22 additions & 12 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,20 +412,27 @@ def query_channel_states_and_params(d, keys, idcs):
def compute_axial_conductances(
comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]
) -> jnp.ndarray:
"""Given `comp_edges`, radius, length, r_a, compute the axial conductances."""
"""Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances.
Note that the resulting axial conductances will already by divided by the
capacitance `cm`.
"""
# `Compartment-to-compartment` (c2c) axial coupling conductances.
condition = comp_edges["type"].to_numpy() == 0
source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())
sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())

if len(sink_comp_inds) > 0:
conds_c2c = vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(
params["radius"][sink_comp_inds],
params["radius"][source_comp_inds],
params["axial_resistivity"][sink_comp_inds],
params["axial_resistivity"][source_comp_inds],
params["length"][sink_comp_inds],
params["length"][source_comp_inds],
conds_c2c = (
vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(
params["radius"][sink_comp_inds],
params["radius"][source_comp_inds],
params["axial_resistivity"][sink_comp_inds],
params["axial_resistivity"][source_comp_inds],
params["length"][sink_comp_inds],
params["length"][source_comp_inds],
)
/ params["capacitance"][sink_comp_inds]
)
else:
conds_c2c = jnp.asarray([])
Expand All @@ -435,10 +442,13 @@ def compute_axial_conductances(
sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())

if len(sink_comp_inds) > 0:
conds_bp2c = vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))(
params["radius"][sink_comp_inds],
params["axial_resistivity"][sink_comp_inds],
params["length"][sink_comp_inds],
conds_bp2c = (
vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))(
params["radius"][sink_comp_inds],
params["axial_resistivity"][sink_comp_inds],
params["length"][sink_comp_inds],
)
/ params["capacitance"][sink_comp_inds]
)
else:
conds_bp2c = jnp.asarray([])
Expand Down
19 changes: 13 additions & 6 deletions tests/jaxley_vs_neuron/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max, solver):

branch.set("length", 10.0)
branch.set("axial_resistivity", 1_000.0)
branch.set("capacitance", 5.0)

branch.set("HH_gNa", 0.120)
branch.set("HH_gK", 0.036)
Expand Down Expand Up @@ -90,6 +91,7 @@ def _run_neuron(i_delay, i_dur, i_amp, dt, t_max, solver):
branch.nseg = nseg_per_branch
branch.Ra = 1_000.0
branch.L = 10.0 * nseg_per_branch
branch.cm = 5.0

radiuses = np.linspace(3.0, 15.0, nseg_per_branch)
for i, comp in enumerate(branch):
Expand Down Expand Up @@ -163,13 +165,18 @@ def test_similarity_complex(solver):
0.9684275792140471,
0.8000000119209283,
]
voltages_jaxley = _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams, solver)
voltages_neuron = _neuron_complex(i_delay, i_dur, i_amp, dt, t_max, diams, solver)
capacitances = np.linspace(1.0, 10.0, 16)
voltages_jaxley = _jaxley_complex(
i_delay, i_dur, i_amp, dt, t_max, diams, capacitances, solver
)
voltages_neuron = _neuron_complex(
i_delay, i_dur, i_amp, dt, t_max, diams, capacitances, solver
)

assert np.mean(np.abs(voltages_jaxley - voltages_neuron)) < 0.05


def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams, solver):
def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams, capacitances, solver):
nseg = 16
comp = jx.Compartment()
branch = jx.Branch(comp, nseg)
Expand All @@ -196,10 +203,9 @@ def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams, solver):
counter = 0
for loc in np.linspace(0, 1, nseg):
branch.loc(loc).set("radius", diams[counter] / 2)
branch.loc(loc).set("capacitance", capacitances[counter])
counter += 1

branch = branch

# 0.02 is fine here because nseg=8 for NEURON, but nseg=16 for jaxley.
branch.loc(0.02).stimulate(jx.step_current(i_delay, i_dur, i_amp, dt, t_max))
branch.loc(0.02).record()
Expand All @@ -210,7 +216,7 @@ def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams, solver):
return s


def _neuron_complex(i_delay, i_dur, i_amp, dt, t_max, diams, solver):
def _neuron_complex(i_delay, i_dur, i_amp, dt, t_max, diams, capacitances, solver):
if solver == "bwd_euler":
h.secondorder = 0
elif solver == "crank_nicolson":
Expand Down Expand Up @@ -246,6 +252,7 @@ def _neuron_complex(i_delay, i_dur, i_amp, dt, t_max, diams, solver):

for i, seg in enumerate(sec):
seg.diam = diams[counter]
seg.cm = capacitances[counter]
counter += 1

# 0.05 is fine here because nseg=8, but nseg=16 for jaxley.
Expand Down
9 changes: 9 additions & 0 deletions tests/jaxley_vs_neuron/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max, solver):
cell.set("radius", 5.0)
cell.set("length", 10.0)
cell.set("axial_resistivity", 1_000.0)
cell.set("capacitance", 7.0)

cell.set("HH_gNa", 0.120)
cell.set("HH_gK", 0.036)
Expand Down Expand Up @@ -94,6 +95,7 @@ def _run_neuron(i_delay, i_dur, i_amp, dt, t_max, solver):
sec.Ra = 1_000.0
sec.L = 10.0 * nseg_per_branch
sec.diam = 2 * 5.0
sec.cm = 7.0

sec.insert("hh")
sec.gnabar_hh = 0.120 # S/cm2
Expand Down Expand Up @@ -160,6 +162,8 @@ def _run_jaxley_unequal_ncomp(i_delay, i_dur, i_amp, dt, t_max):
cell.set("radius", 5.0)
cell.set("length", 20.0)
cell.set("axial_resistivity", 1_000.0)
cell.branch(1).set("capacitance", 10.0)
cell.branch(3).set("capacitance", 20.0)

cell.set("HH_gNa", 0.120)
cell.set("HH_gK", 0.036)
Expand Down Expand Up @@ -212,6 +216,11 @@ def _run_neuron_unequal_ncomp(i_delay, i_dur, i_amp, dt, t_max):
sec.ek = -77.0 # mV
sec.el_hh = -54.3 # mV

if i == 1:
sec.cm = 10.0
if i == 3:
sec.cm = 20.0

stim = h.IClamp(branch2(0.6)) # The second out of two.
stim.delay = i_delay
stim.dur = i_dur
Expand Down

0 comments on commit 7f39df7

Please sign in to comment.