30
30
compute_axial_conductances ,
31
31
compute_current_density ,
32
32
compute_levels ,
33
- interpolate_xyz ,
34
- loc_of_index ,
33
+ interpolate_xyzr ,
35
34
params_to_pstate ,
36
35
query_states_and_params ,
37
36
v_interp ,
@@ -228,6 +227,7 @@ def __getattr__(self, key):
228
227
view ._set_controlled_by_param (key ) # overwrites param set by edge
229
228
# Ensure synapse param sharing works with `edge`
230
229
# `edge` will be removed as part of #463
230
+ view .edges ["local_edge_index" ] = np .arange (len (view .edges ))
231
231
return view
232
232
233
233
def _childviews (self ) -> List [str ]:
@@ -1198,9 +1198,9 @@ def _get_state_names(self) -> Tuple[List, List]:
1198
1198
"""Collect all recordable / clampable states in the membrane and synapses.
1199
1199
1200
1200
Returns states seperated by comps and edges."""
1201
- channel_states = [name for c in self .channels for name in c .channel_states ]
1201
+ channel_states = [name for c in self .channels for name in c .states ]
1202
1202
synapse_states = [
1203
- name for s in self .synapses if s is not None for name in s .synapse_states
1203
+ name for s in self .synapses if s is not None for name in s .states
1204
1204
]
1205
1205
membrane_states = ["v" , "i" ] + self .membrane_current_names
1206
1206
return (
@@ -1219,6 +1219,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
1219
1219
"""
1220
1220
return self .trainable_params
1221
1221
1222
+ @only_allow_module
1223
+ def _iter_states_or_params (self , type = "states" ) -> Dict [str , jnp .ndarray ]:
1224
+ # TODO FROM #447: MAKE THIS WORK FOR VIEW?
1225
+ """Return states as they are set in the `.nodes` and `.edges` tables."""
1226
+ morph_params = ["radius" , "length" , "axial_resistivity" , "capacitance" ]
1227
+ global_states = ["v" ]
1228
+ global_states_or_params = morph_params if type == "params" else global_states
1229
+ for key in global_states_or_params :
1230
+ yield key , self .base .jaxnodes ["index" ], self .base .jaxnodes [key ]
1231
+
1232
+ # Join node and edge states into a single state dictionary.
1233
+ for jax_arrays , mechs in zip (
1234
+ [self .base .jaxnodes , self .base .jaxedges ],
1235
+ [self .base .channels , self .base .synapses ],
1236
+ ):
1237
+ for mech in mechs :
1238
+ mech_inds = jax_arrays [mech ._name ]
1239
+ for key in mech .__dict__ [type ]:
1240
+ yield key , mech_inds , jax_arrays [key ]
1241
+
1222
1242
@only_allow_module
1223
1243
def get_all_parameters (
1224
1244
self , pstate : List [Dict ], voltage_solver : str
@@ -1255,34 +1275,24 @@ def get_all_parameters(
1255
1275
Returns:
1256
1276
A dictionary of all module parameters.
1257
1277
"""
1258
- params = {}
1259
- morph_params = ["radius" , "length" , "axial_resistivity" , "capacitance" ]
1260
- for key in ["v" ] + morph_params :
1261
- params [key ] = self .base .jaxnodes [key ]
1278
+ pstate_inds = {d ["key" ]: i for i , d in enumerate (pstate )}
1262
1279
1263
- for jax_arrays , data , mechs in zip (
1264
- [self .base .jaxnodes , self .base .jaxedges ],
1265
- [self .base .nodes , self .base .edges ],
1266
- [self .base .channels , self .base .synapses ],
1267
- ):
1268
- for mech in mechs :
1269
- inds = jax_arrays [mech ._name ]
1270
- for mech_param in mech .params :
1271
- params [mech_param ] = data [mech_param ].to_numpy ()
1272
- params [mech_param ][inds ] = jax_arrays [mech_param ]
1273
- params [mech_param ] = jnp .asarray (params [mech_param ])
1280
+ params = {}
1281
+ for key , mech_inds , jax_array in self ._iter_states_or_params ("params" ):
1282
+ params [key ] = jax_array
1274
1283
1275
- # Override with those parameters set by `.make_trainable()`.
1276
- for parameter in pstate :
1277
- key = parameter ["key" ]
1278
- inds = parameter ["indices" ]
1279
- set_param = parameter ["val" ]
1284
+ # Override with those parameters set by `.make_trainable()`.
1285
+ if key in pstate_inds :
1286
+ idx = pstate_inds [key ]
1287
+ key = pstate [idx ]["key" ]
1288
+ inds = pstate [idx ]["indices" ]
1289
+ set_param = pstate [idx ]["val" ]
1280
1290
1281
- if key in params : # Only parameters, not initial states.
1282
1291
# `inds` is of shape `(num_params, num_comps_per_param)`.
1283
1292
# `set_param` is of shape `(num_params,)`
1284
- # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the
1285
- # `.set()` to work. This is done with `[:, None]`.
1293
+ # We need to unsqueeze `set_param` to make it `(num_params, 1)`
1294
+ # for the `.set()` to work. This is done with `[:, None]`.
1295
+ inds = np .searchsorted (mech_inds , inds )
1286
1296
params [key ] = params [key ].at [inds ].set (set_param [:, None ])
1287
1297
1288
1298
# Compute conductance params and add them to the params dictionary.
@@ -1291,20 +1301,6 @@ def get_all_parameters(
1291
1301
)
1292
1302
return params
1293
1303
1294
- @only_allow_module
1295
- def _get_states_from_nodes_and_edges (self ) -> Dict [str , jnp .ndarray ]:
1296
- # TODO FROM #447: MAKE THIS WORK FOR VIEW?
1297
- """Return states as they are set in the `.nodes` and `.edges` tables."""
1298
- self .base .to_jax () # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
1299
- states = {"v" : self .base .jaxnodes ["v" ]}
1300
- # Join node and edge states into a single state dictionary.
1301
- for channel in self .base .channels :
1302
- for channel_states in channel .states :
1303
- states [channel_states ] = self .base .jaxnodes [channel_states ]
1304
- for synapse_states in self .base .synapse_state_names :
1305
- states [synapse_states ] = self .base .jaxedges [synapse_states ]
1306
- return states
1307
-
1308
1304
@only_allow_module
1309
1305
def get_all_states (
1310
1306
self , pstate : List [Dict ], all_params , delta_t : float
@@ -1320,18 +1316,23 @@ def get_all_states(
1320
1316
Returns:
1321
1317
A dictionary of all states of the module.
1322
1318
"""
1323
- states = self .base ._get_states_from_nodes_and_edges ()
1324
-
1325
- # Override with the initial states set by `.make_trainable()`.
1326
- for parameter in pstate :
1327
- key = parameter ["key" ]
1328
- inds = parameter ["indices" ]
1329
- set_param = parameter ["val" ]
1330
- if key in states : # Only initial states, not parameters.
1331
- # `inds` is of shape `(num_params, num_comps_per_param)`.
1332
- # `set_param` is of shape `(num_params,)`
1333
- # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the
1334
- # `.set()` to work. This is done with `[:, None]`.
1319
+ pstate_inds = {d ["key" ]: i for i , d in enumerate (pstate )}
1320
+ states = {}
1321
+ for key , mech_inds , jax_array in self ._iter_states_or_params ("states" ):
1322
+ states [key ] = jax_array
1323
+
1324
+ # Override with those parameters set by `.make_trainable()`.
1325
+ if key in pstate_inds :
1326
+ idx = pstate_inds [key ]
1327
+ key = pstate [idx ]["key" ]
1328
+ inds = pstate [idx ]["indices" ]
1329
+ set_param = pstate [idx ]["val" ]
1330
+
1331
+ # `inds` is of shape `(num_states, num_comps_per_param)`.
1332
+ # `set_param` is of shape `(num_states,)`
1333
+ # We need to unsqueeze `set_param` to make it `(num_states, 1)`
1334
+ # for the `.set()` to work. This is done with `[:, None]`.
1335
+ inds = np .searchsorted (mech_inds , inds )
1335
1336
states [key ] = states [key ].at [inds ].set (set_param [:, None ])
1336
1337
1337
1338
# Add to the states the initial current through every channel.
@@ -1366,8 +1367,11 @@ def init_states(self, delta_t: float = 0.025):
1366
1367
delta_t: Passed on to `channel.init_state()`.
1367
1368
"""
1368
1369
# Update states of the channels.
1370
+ self .base .to_jax () # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
1369
1371
channel_nodes = self .base .nodes
1370
- states = self .base ._get_states_from_nodes_and_edges ()
1372
+ states = {}
1373
+ for key , _ , jax_array in self ._iter_states_or_params ("states" ):
1374
+ states [key ] = jax_array
1371
1375
1372
1376
# We do not use any `pstate` for initializing. In principle, we could change
1373
1377
# that by allowing an input `params` and `pstate` to this function.
0 commit comments