Skip to content

Commit aa4ae5f

Browse files
committed
wip: more tests passing, small refactor
1 parent e3e2000 commit aa4ae5f

File tree

1 file changed

+58
-54
lines changed

1 file changed

+58
-54
lines changed

jaxley/modules/base.py

+58-54
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
compute_axial_conductances,
3131
compute_current_density,
3232
compute_levels,
33-
interpolate_xyz,
34-
loc_of_index,
33+
interpolate_xyzr,
3534
params_to_pstate,
3635
query_states_and_params,
3736
v_interp,
@@ -228,6 +227,7 @@ def __getattr__(self, key):
228227
view._set_controlled_by_param(key) # overwrites param set by edge
229228
# Ensure synapse param sharing works with `edge`
230229
# `edge` will be removed as part of #463
230+
view.edges["local_edge_index"] = np.arange(len(view.edges))
231231
return view
232232

233233
def _childviews(self) -> List[str]:
@@ -1198,9 +1198,9 @@ def _get_state_names(self) -> Tuple[List, List]:
11981198
"""Collect all recordable / clampable states in the membrane and synapses.
11991199
12001200
Returns states seperated by comps and edges."""
1201-
channel_states = [name for c in self.channels for name in c.channel_states]
1201+
channel_states = [name for c in self.channels for name in c.states]
12021202
synapse_states = [
1203-
name for s in self.synapses if s is not None for name in s.synapse_states
1203+
name for s in self.synapses if s is not None for name in s.states
12041204
]
12051205
membrane_states = ["v", "i"] + self.membrane_current_names
12061206
return (
@@ -1219,6 +1219,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
12191219
"""
12201220
return self.trainable_params
12211221

1222+
@only_allow_module
1223+
def _iter_states_or_params(self, type="states") -> Dict[str, jnp.ndarray]:
1224+
# TODO FROM #447: MAKE THIS WORK FOR VIEW?
1225+
"""Return states as they are set in the `.nodes` and `.edges` tables."""
1226+
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
1227+
global_states = ["v"]
1228+
global_states_or_params = morph_params if type == "params" else global_states
1229+
for key in global_states_or_params:
1230+
yield key, self.base.jaxnodes["index"], self.base.jaxnodes[key]
1231+
1232+
# Join node and edge states into a single state dictionary.
1233+
for jax_arrays, mechs in zip(
1234+
[self.base.jaxnodes, self.base.jaxedges],
1235+
[self.base.channels, self.base.synapses],
1236+
):
1237+
for mech in mechs:
1238+
mech_inds = jax_arrays[mech._name]
1239+
for key in mech.__dict__[type]:
1240+
yield key, mech_inds, jax_arrays[key]
1241+
12221242
@only_allow_module
12231243
def get_all_parameters(
12241244
self, pstate: List[Dict], voltage_solver: str
@@ -1255,34 +1275,24 @@ def get_all_parameters(
12551275
Returns:
12561276
A dictionary of all module parameters.
12571277
"""
1258-
params = {}
1259-
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
1260-
for key in ["v"] + morph_params:
1261-
params[key] = self.base.jaxnodes[key]
1278+
pstate_inds = {d["key"]: i for i, d in enumerate(pstate)}
12621279

1263-
for jax_arrays, data, mechs in zip(
1264-
[self.base.jaxnodes, self.base.jaxedges],
1265-
[self.base.nodes, self.base.edges],
1266-
[self.base.channels, self.base.synapses],
1267-
):
1268-
for mech in mechs:
1269-
inds = jax_arrays[mech._name]
1270-
for mech_param in mech.params:
1271-
params[mech_param] = data[mech_param].to_numpy()
1272-
params[mech_param][inds] = jax_arrays[mech_param]
1273-
params[mech_param] = jnp.asarray(params[mech_param])
1280+
params = {}
1281+
for key, mech_inds, jax_array in self._iter_states_or_params("params"):
1282+
params[key] = jax_array
12741283

1275-
# Override with those parameters set by `.make_trainable()`.
1276-
for parameter in pstate:
1277-
key = parameter["key"]
1278-
inds = parameter["indices"]
1279-
set_param = parameter["val"]
1284+
# Override with those parameters set by `.make_trainable()`.
1285+
if key in pstate_inds:
1286+
idx = pstate_inds[key]
1287+
key = pstate[idx]["key"]
1288+
inds = pstate[idx]["indices"]
1289+
set_param = pstate[idx]["val"]
12801290

1281-
if key in params: # Only parameters, not initial states.
12821291
# `inds` is of shape `(num_params, num_comps_per_param)`.
12831292
# `set_param` is of shape `(num_params,)`
1284-
# We need to unsqueeze `set_param` to make it `(num_params, 1)` for the
1285-
# `.set()` to work. This is done with `[:, None]`.
1293+
# We need to unsqueeze `set_param` to make it `(num_params, 1)`
1294+
# for the `.set()` to work. This is done with `[:, None]`.
1295+
inds = np.searchsorted(mech_inds, inds)
12861296
params[key] = params[key].at[inds].set(set_param[:, None])
12871297

12881298
# Compute conductance params and add them to the params dictionary.
@@ -1291,20 +1301,6 @@ def get_all_parameters(
12911301
)
12921302
return params
12931303

1294-
@only_allow_module
1295-
def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:
1296-
# TODO FROM #447: MAKE THIS WORK FOR VIEW?
1297-
"""Return states as they are set in the `.nodes` and `.edges` tables."""
1298-
self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
1299-
states = {"v": self.base.jaxnodes["v"]}
1300-
# Join node and edge states into a single state dictionary.
1301-
for channel in self.base.channels:
1302-
for channel_states in channel.states:
1303-
states[channel_states] = self.base.jaxnodes[channel_states]
1304-
for synapse_states in self.base.synapse_state_names:
1305-
states[synapse_states] = self.base.jaxedges[synapse_states]
1306-
return states
1307-
13081304
@only_allow_module
13091305
def get_all_states(
13101306
self, pstate: List[Dict], all_params, delta_t: float
@@ -1320,18 +1316,23 @@ def get_all_states(
13201316
Returns:
13211317
A dictionary of all states of the module.
13221318
"""
1323-
states = self.base._get_states_from_nodes_and_edges()
1324-
1325-
# Override with the initial states set by `.make_trainable()`.
1326-
for parameter in pstate:
1327-
key = parameter["key"]
1328-
inds = parameter["indices"]
1329-
set_param = parameter["val"]
1330-
if key in states: # Only initial states, not parameters.
1331-
# `inds` is of shape `(num_params, num_comps_per_param)`.
1332-
# `set_param` is of shape `(num_params,)`
1333-
# We need to unsqueeze `set_param` to make it `(num_params, 1)` for the
1334-
# `.set()` to work. This is done with `[:, None]`.
1319+
pstate_inds = {d["key"]: i for i, d in enumerate(pstate)}
1320+
states = {}
1321+
for key, mech_inds, jax_array in self._iter_states_or_params("states"):
1322+
states[key] = jax_array
1323+
1324+
# Override with those parameters set by `.make_trainable()`.
1325+
if key in pstate_inds:
1326+
idx = pstate_inds[key]
1327+
key = pstate[idx]["key"]
1328+
inds = pstate[idx]["indices"]
1329+
set_param = pstate[idx]["val"]
1330+
1331+
# `inds` is of shape `(num_states, num_comps_per_param)`.
1332+
# `set_param` is of shape `(num_states,)`
1333+
# We need to unsqueeze `set_param` to make it `(num_states, 1)`
1334+
# for the `.set()` to work. This is done with `[:, None]`.
1335+
inds = np.searchsorted(mech_inds, inds)
13351336
states[key] = states[key].at[inds].set(set_param[:, None])
13361337

13371338
# Add to the states the initial current through every channel.
@@ -1366,8 +1367,11 @@ def init_states(self, delta_t: float = 0.025):
13661367
delta_t: Passed on to `channel.init_state()`.
13671368
"""
13681369
# Update states of the channels.
1370+
self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
13691371
channel_nodes = self.base.nodes
1370-
states = self.base._get_states_from_nodes_and_edges()
1372+
states = {}
1373+
for key, _, jax_array in self._iter_states_or_params("states"):
1374+
states[key] = jax_array
13711375

13721376
# We do not use any `pstate` for initializing. In principle, we could change
13731377
# that by allowing an input `params` and `pstate` to this function.

0 commit comments

Comments
 (0)