Skip to content

Commit 41ce5c3

Browse files
Bugfix for init_states() when channel does not exist in all comps (#421)
* Bugfix for using `init_states()` when channel does not exist in all comps * add jaxley-mech to dev dependencies for testing * add test for complex channel init_states * bugfix for workflow * adapt tutorial to #416 * formatting * bugfix for pyproject.toml
1 parent c2c6687 commit 41ce5c3

File tree

7 files changed

+154
-37
lines changed

7 files changed

+154
-37
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
- name: Install dependencies
2727
run: |
2828
python -m pip install --upgrade pip
29-
pip install -e .[dev]
29+
pip install -e ".[dev]"
3030
3131
- name: Check formatting with black
3232
run: |

jaxley/channels/channel.py

+10
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,13 @@ def compute_current(
7979
Current in `uA/cm2`.
8080
"""
8181
raise NotImplementedError
82+
83+
def init_state(
84+
self,
85+
states: Dict[str, jnp.ndarray],
86+
v: jnp.ndarray,
87+
params: Dict[str, jnp.ndarray],
88+
delta_t: float,
89+
):
90+
"""Initialize states of channel."""
91+
return {}

jaxley/modules/base.py

+26-14
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
convert_point_process_to_distributed,
2525
interpolate_xyz,
2626
loc_of_index,
27+
query_channel_states_and_params,
2728
v_interp,
2829
)
2930
from jaxley.utils.debug_solver import compute_morphology_indices, convert_to_csc
@@ -608,25 +609,37 @@ def init_states(self, delta_t: float = 0.025):
608609
channel_nodes = self.nodes
609610
states = self.get_states_from_nodes_and_edges()
610611

612+
# We do not use any `pstate` for initializing. In principle, we could change
613+
# that by allowing an input `params` and `pstate` to this function.
614+
params = self.get_all_parameters([])
615+
611616
for channel in self.channels:
612617
name = channel._name
613-
indices = channel_nodes.loc[channel_nodes[name]]["comp_index"].to_numpy()
614-
voltages = channel_nodes.loc[indices, "v"].to_numpy()
618+
channel_indices = channel_nodes.loc[channel_nodes[name]][
619+
"comp_index"
620+
].to_numpy()
621+
voltages = channel_nodes.loc[channel_indices, "v"].to_numpy()
615622

616623
channel_param_names = list(channel.channel_params.keys())
617-
channel_params = {}
618-
for p in channel_param_names:
619-
channel_params[p] = channel_nodes[p][indices].to_numpy()
624+
channel_state_names = list(channel.channel_states.keys())
625+
channel_states = query_channel_states_and_params(
626+
states, channel_state_names, channel_indices
627+
)
628+
channel_params = query_channel_states_and_params(
629+
params, channel_param_names, channel_indices
630+
)
620631

621-
init_state = channel.init_state(states, voltages, channel_params, delta_t)
632+
init_state = channel.init_state(
633+
channel_states, voltages, channel_params, delta_t
634+
)
622635

623636
# `init_state` might not return all channel states. Only the ones that are
624637
# returned are updated here.
625638
for key, val in init_state.items():
626639
# Note that we are overriding `self.nodes` here, but `self.nodes` is
627640
# not used above to actually compute the current states (so there are
628641
# no issues with overriding states).
629-
self.nodes.loc[indices, key] = val
642+
self.nodes.loc[channel_indices, key] = val
630643

631644
def _init_morph_for_debugging(self):
632645
"""Instandiates row and column inds which can be used to solve the voltage eqs.
@@ -982,11 +995,6 @@ def _step_channels_state(
982995
"""One integration step of the channels."""
983996
voltages = states["v"]
984997

985-
query = lambda d, keys, idcs: dict(
986-
zip(keys, (v[idcs] for v in map(d.get, keys)))
987-
) # get dict with subset of keys and values from d
988-
# only loops over necessary keys, as opposed to looping over d.items()
989-
990998
# Update states of the channels.
991999
indices = channel_nodes["comp_index"].to_numpy()
9921000
for channel in channels:
@@ -996,8 +1004,12 @@ def _step_channels_state(
9961004
channel_state_names += self.membrane_current_names
9971005
channel_indices = indices[channel_nodes[channel._name].astype(bool)]
9981006

999-
channel_params = query(params, channel_param_names, channel_indices)
1000-
channel_states = query(states, channel_state_names, channel_indices)
1007+
channel_params = query_channel_states_and_params(
1008+
params, channel_param_names, channel_indices
1009+
)
1010+
channel_states = query_channel_states_and_params(
1011+
states, channel_state_names, channel_indices
1012+
)
10011013

10021014
states_updated = channel.update_states(
10031015
channel_states, delta_t, voltages[channel_indices], channel_params

jaxley/utils/cell_utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,18 @@ def group_and_sum(
396396
group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)
397397

398398
return group_sums
399+
400+
401+
def query_channel_states_and_params(d, keys, idcs):
402+
"""Get dict with subset of keys and values from d.
403+
404+
This is used to restrict a dict where every item contains __all__ states to only
405+
the ones that are relevant for the channel. E.g.
406+
407+
```states = {'eCa': Array([ 0., 0., nan]}```
408+
409+
will be
410+
```states = {'eCa': Array([ 0., 0.]}```
411+
412+
Only loops over necessary keys, as opposed to looping over `d.items()`."""
413+
return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ doc = [
5050
dev = [
5151
"black",
5252
"isort",
53+
"jaxley-mech",
5354
"neuron",
5455
"pytest",
5556
"pyright",

tests/test_channels.py

+89-1
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66
jax.config.update("jax_enable_x64", True)
77
jax.config.update("jax_platform_name", "cpu")
8-
from typing import Optional
8+
from typing import Dict, Optional
99

1010
import jax.numpy as jnp
1111
import numpy as np
1212
import pytest
13+
from jaxley_mech.channels.l5pc import CaNernstReversal, CaPump
1314

1415
import jaxley as jx
1516
from jaxley.channels import HH, CaL, CaT, Channel, K, Km, Leak, Na
17+
from jaxley.solver_gate import save_exp, solve_inf_gate_exponential
1618

1719

1820
def test_channel_set_name():
@@ -101,6 +103,92 @@ def test_init_states():
101103
assert np.abs(v[0, 0] - v[0, -1]) < 0.02
102104

103105

106+
class KCA11(Channel):
107+
def __init__(self, name: Optional[str] = None):
108+
super().__init__(name)
109+
prefix = self._name
110+
self.channel_params = {
111+
f"{prefix}_q10_ch": 3,
112+
f"{prefix}_q10_ch0": 22,
113+
"celsius": 22,
114+
}
115+
self.channel_states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4}
116+
self.current_name = f"i_K"
117+
118+
def update_states(
119+
self,
120+
states: Dict[str, jnp.ndarray],
121+
dt,
122+
v,
123+
params: Dict[str, jnp.ndarray],
124+
):
125+
"""Update state."""
126+
prefix = self._name
127+
m = states[f"{prefix}_m"]
128+
q10 = params[f"{prefix}_q10_ch"] ** (
129+
(params["celsius"] - params[f"{prefix}_q10_ch0"]) / 10
130+
)
131+
cai = states["CaCon_i"]
132+
new_m = solve_inf_gate_exponential(m, dt, *self.m_gate(v, cai, q10))
133+
return {f"{prefix}_m": new_m}
134+
135+
def compute_current(
136+
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
137+
):
138+
"""Return current."""
139+
prefix = self._name
140+
m = states[f"{prefix}_m"]
141+
g = 0.03 * m * 1000 # mS/cm^2
142+
return g * (v + 80.0)
143+
144+
def init_state(self, states, v, params, dt):
145+
"""Initialize the state such at fixed point of gate dynamics."""
146+
prefix = self._name
147+
q10 = params[f"{prefix}_q10_ch"] ** (
148+
(params["celsius"] - params[f"{prefix}_q10_ch0"]) / 10
149+
)
150+
cai = states["CaCon_i"]
151+
m_inf, _ = self.m_gate(v, cai, q10)
152+
return {f"{prefix}_m": m_inf}
153+
154+
@staticmethod
155+
def m_gate(v, cai, q10):
156+
cai = cai * 1e3
157+
v_half = -66 + 137 * save_exp(-0.3044 * cai) + 30.24 * save_exp(-0.04141 * cai)
158+
alpha = 25.0
159+
160+
beta = 0.075 / save_exp((v - v_half) / 10)
161+
m_inf = alpha / (alpha + beta)
162+
tau_m = 1.0 * q10
163+
return m_inf, tau_m
164+
165+
166+
def test_init_states_complex_channel():
167+
"""Test for `init_states()` with a more complicated channel model.
168+
169+
The channel model used for this test uses the `states` in `init_state` and it also
170+
uses `q10`. The model inserts the channel only is some branches. This test follows
171+
an issue I had with Jaxley in v0.2.0 (fixed in v0.2.1).
172+
"""
173+
## Create cell
174+
comp = jx.Compartment()
175+
branch = jx.Branch(comp, nseg=1)
176+
cell = jx.Cell(branch, parents=[-1, 0, 0])
177+
178+
# CA channels.
179+
cell.branch([0, 1]).insert(CaNernstReversal())
180+
cell.branch([0, 1]).insert(CaPump())
181+
cell.branch([0, 1]).insert(KCA11())
182+
183+
cell.init_states()
184+
185+
current = jx.step_current(1.0, 1.0, 0.1, 0.025, 3.0)
186+
cell.branch(2).comp(0).stimulate(current)
187+
cell.branch(2).comp(0).record()
188+
voltages = jx.integrate(cell)
189+
assert np.invert(np.any(np.isnan(voltages))), "NaN voltage found"
190+
191+
104192
def test_multiple_channel_currents():
105193
"""Test whether all channels can"""
106194

0 commit comments

Comments
 (0)