Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User exposing init and step for customization. #466

Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 152 additions & 43 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from math import prod
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import pandas as pd
Expand All @@ -12,6 +12,151 @@
from jaxley.utils.jax_utils import nested_checkpoint_scan


def build_init_and_step_fn(
module: Module,
voltage_solver: str = "jaxley.stone",
solver: str = "bwd_euler",
) -> Tuple[Callable, Callable]:
"""This function returns the `init_fn` and `step_fn` which initialize the
parameters and states of the neuron model and then step through the model

Args:
module (Module): A `Module` object that e.g. a cell.
voltage_solver (str, optional): Voltage solver used in step. Defaults to "jaxley.stone".
solver (str, optional): ODE solver. Defaults to "bwd_euler".

Returns:
init_fn, step_fn: Functions that initialize the state and parameters, and perform
a single integration step, respectively.
"""
# Initialize the external inputs and their indices.
external_inds = module.external_inds.copy()

def init_fn(
params: List[Dict[str, jnp.ndarray]],
all_states: Optional[Dict] = None,
param_state: Optional[List[Dict]] = None,
delta_t: float = 0.025,
) -> Tuple[Dict, Dict]:
"""Initializes the parameters and states of the neuron model.

Args:
params (List[Dict[str, jnp.ndarray]]): List of trainable parameters.
all_states (Optional[Dict], optional): State if alread initialized. Defaults to None.
param_state (Optional[List[Dict]], optional): Parameters returned by `data_set`.. Defaults to None.
delta_t (float, optional): Step size. Defaults to 0.025.

Returns:
Tuple[Dict, Dict]: All states and parameters.
"""
# Make the `trainable_params` of the same shape as the `param_state`, such that
# they can be processed together by `get_all_parameters`.
pstate = params_to_pstate(params, module.indices_set_by_trainables)
if param_state is not None:
pstate += param_state

all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)
all_states = (
module.get_all_states(pstate, all_params, delta_t)
if all_states is None
else all_states
)
return all_states, all_params

def step_fn(
all_states: Dict,
all_params: Dict,
externals: Dict,
external_inds: Dict = external_inds,
delta_t: float = 0.025,
) -> Dict:
"""Performs a single integration step with step size delta_t.

Args:
all_states (Dict): Current state of the neuron model.
all_params (Dict): Current parameters of the neuron model.
externals (Dict): External inputs.
external_inds (Dict, optional): External indices. Defaults to `module.external_inds`.
delta_t (float, optional): Time step. Defaults to 0.025.

Returns:
Dict: Updated states.
"""
state = all_states
state = module.step(
state,
delta_t,
external_inds,
externals,
params=all_params,
solver=solver,
voltage_solver=voltage_solver,
)
return state

return init_fn, step_fn


def add_stimuli(
externals: Dict,
external_inds: Dict,
data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,
) -> Tuple[Dict, Dict]:
"""Extends the external inputs with the stimuli.

Args:
externals (Dict): Current external inputs.
external_inds (Dict): Current external indices.
data_stimuli (Optional[Tuple[jnp.ndarray, pd.DataFrame]], optional): Additional data stimuli. Defaults to None.

Returns:
Tuple[Dict, Dict]: Updated external inputs and indices.
"""
# If stimulus is inserted, add it to the external inputs.
if "i" in externals.keys() or data_stimuli is not None:
if "i" in externals.keys():
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
)
else:
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()

return externals, external_inds


def add_clamps(
externals: Dict,
external_inds: Dict,
data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,
) -> Tuple[Dict, Dict]:
"""Adds clamps to the external inputs.

Args:
externals (Dict): Current external inputs.
external_inds (Dict): Current external indices.
data_clamps (Optional[Tuple[str, jnp.ndarray, pd.DataFrame]], optional): Additional data clamps. Defaults to None.

Returns:
Tuple[Dict, Dict]: Updated external inputs and indices.
"""
# If a clamp is inserted, add it to the external inputs.
if data_clamps is not None:
state_name, clamps, inds = data_clamps
if state_name in externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.global_comp_index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.global_comp_index.to_numpy()

return externals, external_inds


def integrate(
module: Module,
params: List[Dict[str, jnp.ndarray]] = [],
Expand Down Expand Up @@ -70,28 +215,10 @@ def integrate(
external_inds = module.external_inds.copy()

# If stimulus is inserted, add it to the external inputs.
if "i" in module.externals.keys() or data_stimuli is not None:
if "i" in module.externals.keys():
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
)
else:
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()
externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)

# If a clamp is inserted, add it to the external inputs.
if data_clamps is not None:
state_name, clamps, inds = data_clamps
if state_name in module.externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.global_comp_index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.global_comp_index.to_numpy()
externals, external_inds = add_clamps(externals, external_inds, data_clamps)

if not externals.keys():
# No stimulus was inserted and no clamp was set.
Expand Down Expand Up @@ -124,31 +251,13 @@ def integrate(
else:
externals[key] = externals[key][:t_max_steps, :]

# Make the `trainable_params` of the same shape as the `param_state`, such that they
# can be processed together by `get_all_parameters`.
pstate = params_to_pstate(params, module.indices_set_by_trainables)

# Gather parameters from `make_trainable` and `data_set` into a single list.
if param_state is not None:
pstate += param_state

all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)
all_states = (
module.get_all_states(pstate, all_params, delta_t)
if all_states is None
else all_states
init_fn, step_fn = build_init_and_step_fn(
module, voltage_solver=voltage_solver, solver=solver
)
all_states, all_params = init_fn(params, all_states, param_state, delta_t)

def _body_fun(state, externals):
state = module.step(
state,
delta_t,
external_inds,
externals,
params=all_params,
solver=solver,
voltage_solver=voltage_solver,
)
state = step_fn(state, all_params, externals, external_inds, delta_t)
recs = jnp.asarray(
[
state[rec_state][rec_ind]
Expand Down
Loading