Skip to content

Commit

Permalink
Add option to use states in init_state
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Sep 13, 2024
1 parent 82cce2d commit e745484
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 17 deletions.
2 changes: 1 addition & 1 deletion jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def compute_current(
+ gLeak * (v - params[f"{prefix}_eLeak"])
)

def init_state(self, v, params):
def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_m, beta_m = self.m_gate(v)
Expand Down
12 changes: 6 additions & 6 deletions jaxley/channels/pospischil.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def compute_current(
gLeak = params[f"{prefix}_gLeak"] * 1000 # mS/cm^2
return gLeak * (v - params[f"{prefix}_eLeak"])

def init_state(self, v, params):
def init_state(self, states, v, params, delta_t):
return {}


Expand Down Expand Up @@ -109,7 +109,7 @@ def compute_current(
current = gNa * (v - params["eNa"])
return current

def init_state(self, v, params):
def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_m, beta_m = self.m_gate(v, params["vt"])
Expand Down Expand Up @@ -177,7 +177,7 @@ def compute_current(

return gK * (v - params["eK"])

def init_state(self, v, params):
def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_n, beta_n = self.n_gate(v, params["vt"])
Expand Down Expand Up @@ -233,7 +233,7 @@ def compute_current(
gKm = params[f"{prefix}_gKm"] * p * 1000 # mS/cm^2
return gKm * (v - params["eK"])

def init_state(self, v, params):
def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_p, beta_p = self.p_gate(v, params[f"{prefix}_taumax"])
Expand Down Expand Up @@ -288,7 +288,7 @@ def compute_current(

return gCaL * (v - params["eCa"])

def init_state(self, v, params):
def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_q, beta_q = self.q_gate(v)
Expand Down Expand Up @@ -359,7 +359,7 @@ def compute_current(

return gCaT * (v - params["eCa"])

def init_state(self, v, params):
def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_u, beta_u = self.u_gate(v, params[f"{prefix}_vx"])
Expand Down
34 changes: 24 additions & 10 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,18 @@ def get_all_parameters(self, pstate: List[Dict]) -> Dict[str, jnp.ndarray]:

return params

def get_states_from_nodes_and_edges(self):
"""Return states as they are set in the `.nodes` and `.edges` tables."""
self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
states = {"v": self.jaxnodes["v"]}
# Join node and edge states into a single state dictionary.
for channel in self.channels:
for channel_states in channel.channel_states:
states[channel_states] = self.jaxnodes[channel_states]
for synapse_states in self.synapse_state_names:
states[synapse_states] = self.jaxedges[synapse_states]
return states

def get_all_states(
self, pstate: List[Dict], all_params, delta_t: float
) -> Dict[str, jnp.ndarray]:
Expand All @@ -549,13 +561,7 @@ def get_all_states(
Returns:
A dictionary of all states of the module.
"""
# Join node and edge states into a single state dictionary.
states = {"v": self.jaxnodes["v"]}
for channel in self.channels:
for channel_states in channel.channel_states:
states[channel_states] = self.jaxnodes[channel_states]
for synapse_states in self.synapse_state_names:
states[synapse_states] = self.jaxedges[synapse_states]
states = self.get_states_from_nodes_and_edges()

# Override with the initial states set by `.make_trainable()`.
for parameter in pstate:
Expand Down Expand Up @@ -590,12 +596,17 @@ def initialize(self):
self.init_morph()
return self

def init_states(self):
def init_states(self, delta_t: float = 0.025):
"""Initialize all mechanisms in their steady state.
This considers the voltages and parameters of each compartment."""
This considers the voltages and parameters of each compartment.
Args:
delta_t: Passed on to `channel.init_state()`.
"""
# Update states of the channels.
channel_nodes = self.nodes
states = self.get_states_from_nodes_and_edges()

for channel in self.channels:
name = channel._name
Expand All @@ -607,11 +618,14 @@ def init_states(self):
for p in channel_param_names:
channel_params[p] = channel_nodes[p][indices].to_numpy()

init_state = channel.init_state(voltages, channel_params)
init_state = channel.init_state(states, voltages, channel_params, delta_t)

# `init_state` might not return all channel states. Only the ones that are
# returned are updated here.
for key, val in init_state.items():
# Note that we are overriding `self.nodes` here, but `self.nodes` is
# not used above to actually compute the current states (so there are
# no issues with overriding states).
self.nodes.loc[indices, key] = val

def _init_morph_for_debugging(self):
Expand Down

0 comments on commit e745484

Please sign in to comment.