|
5 | 5 |
|
6 | 6 | jax.config.update("jax_enable_x64", True)
|
7 | 7 | jax.config.update("jax_platform_name", "cpu")
|
8 |
| -from typing import Optional |
| 8 | +from typing import Dict, Optional |
9 | 9 |
|
10 | 10 | import jax.numpy as jnp
|
11 | 11 | import numpy as np
|
12 | 12 | import pytest
|
| 13 | +from jaxley_mech.channels.l5pc import CaNernstReversal, CaPump |
13 | 14 |
|
14 | 15 | import jaxley as jx
|
15 | 16 | from jaxley.channels import HH, CaL, CaT, Channel, K, Km, Leak, Na
|
| 17 | +from jaxley.solver_gate import save_exp, solve_inf_gate_exponential |
16 | 18 |
|
17 | 19 |
|
18 | 20 | def test_channel_set_name():
|
@@ -101,6 +103,92 @@ def test_init_states():
|
101 | 103 | assert np.abs(v[0, 0] - v[0, -1]) < 0.02
|
102 | 104 |
|
103 | 105 |
|
| 106 | +class KCA11(Channel): |
| 107 | + def __init__(self, name: Optional[str] = None): |
| 108 | + super().__init__(name) |
| 109 | + prefix = self._name |
| 110 | + self.channel_params = { |
| 111 | + f"{prefix}_q10_ch": 3, |
| 112 | + f"{prefix}_q10_ch0": 22, |
| 113 | + "celsius": 22, |
| 114 | + } |
| 115 | + self.channel_states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} |
| 116 | + self.current_name = f"i_K" |
| 117 | + |
| 118 | + def update_states( |
| 119 | + self, |
| 120 | + states: Dict[str, jnp.ndarray], |
| 121 | + dt, |
| 122 | + v, |
| 123 | + params: Dict[str, jnp.ndarray], |
| 124 | + ): |
| 125 | + """Update state.""" |
| 126 | + prefix = self._name |
| 127 | + m = states[f"{prefix}_m"] |
| 128 | + q10 = params[f"{prefix}_q10_ch"] ** ( |
| 129 | + (params["celsius"] - params[f"{prefix}_q10_ch0"]) / 10 |
| 130 | + ) |
| 131 | + cai = states["CaCon_i"] |
| 132 | + new_m = solve_inf_gate_exponential(m, dt, *self.m_gate(v, cai, q10)) |
| 133 | + return {f"{prefix}_m": new_m} |
| 134 | + |
| 135 | + def compute_current( |
| 136 | + self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] |
| 137 | + ): |
| 138 | + """Return current.""" |
| 139 | + prefix = self._name |
| 140 | + m = states[f"{prefix}_m"] |
| 141 | + g = 0.03 * m * 1000 # mS/cm^2 |
| 142 | + return g * (v + 80.0) |
| 143 | + |
| 144 | + def init_state(self, states, v, params, dt): |
| 145 | + """Initialize the state such at fixed point of gate dynamics.""" |
| 146 | + prefix = self._name |
| 147 | + q10 = params[f"{prefix}_q10_ch"] ** ( |
| 148 | + (params["celsius"] - params[f"{prefix}_q10_ch0"]) / 10 |
| 149 | + ) |
| 150 | + cai = states["CaCon_i"] |
| 151 | + m_inf, _ = self.m_gate(v, cai, q10) |
| 152 | + return {f"{prefix}_m": m_inf} |
| 153 | + |
| 154 | + @staticmethod |
| 155 | + def m_gate(v, cai, q10): |
| 156 | + cai = cai * 1e3 |
| 157 | + v_half = -66 + 137 * save_exp(-0.3044 * cai) + 30.24 * save_exp(-0.04141 * cai) |
| 158 | + alpha = 25.0 |
| 159 | + |
| 160 | + beta = 0.075 / save_exp((v - v_half) / 10) |
| 161 | + m_inf = alpha / (alpha + beta) |
| 162 | + tau_m = 1.0 * q10 |
| 163 | + return m_inf, tau_m |
| 164 | + |
| 165 | + |
| 166 | +def test_init_states_complex_channel(): |
| 167 | + """Test for `init_states()` with a more complicated channel model. |
| 168 | +
|
| 169 | + The channel model used for this test uses the `states` in `init_state` and it also |
| 170 | + uses `q10`. The model inserts the channel only is some branches. This test follows |
| 171 | + an issue I had with Jaxley in v0.2.0 (fixed in v0.2.1). |
| 172 | + """ |
| 173 | + ## Create cell |
| 174 | + comp = jx.Compartment() |
| 175 | + branch = jx.Branch(comp, nseg=1) |
| 176 | + cell = jx.Cell(branch, parents=[-1, 0, 0]) |
| 177 | + |
| 178 | + # CA channels. |
| 179 | + cell.branch([0, 1]).insert(CaNernstReversal()) |
| 180 | + cell.branch([0, 1]).insert(CaPump()) |
| 181 | + cell.branch([0, 1]).insert(KCA11()) |
| 182 | + |
| 183 | + cell.init_states() |
| 184 | + |
| 185 | + current = jx.step_current(1.0, 1.0, 0.1, 0.025, 3.0) |
| 186 | + cell.branch(2).comp(0).stimulate(current) |
| 187 | + cell.branch(2).comp(0).record() |
| 188 | + voltages = jx.integrate(cell) |
| 189 | + assert np.invert(np.any(np.isnan(voltages))), "NaN voltage found" |
| 190 | + |
| 191 | + |
104 | 192 | def test_multiple_channel_currents():
|
105 | 193 | """Test whether all channels can"""
|
106 | 194 |
|
|
0 commit comments