Skip to content

Commit 03217be

Browse files
Allow changing nseg after initialization
1 parent 4db0a8f commit 03217be

File tree

7 files changed

+511
-38
lines changed

7 files changed

+511
-38
lines changed

jaxley/modules/base.py

+134
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from abc import ABC, abstractmethod
66
from copy import deepcopy
77
from typing import Callable, Dict, List, Optional, Tuple, Union
8+
from warnings import warn
89

910
import jax.numpy as jnp
1011
import numpy as np
@@ -35,6 +36,7 @@
3536
from jaxley.utils.misc_utils import childview, concat_and_ignore_empty
3637
from jaxley.utils.plot_utils import plot_morph
3738
from jaxley.utils.solver_utils import convert_to_csc
39+
from jaxley.utils.swc import build_radiuses_from_xyzr
3840

3941

4042
class Module(ABC):
@@ -109,6 +111,7 @@ def __init__(self):
109111

110112
# x, y, z coordinates and radius.
111113
self.xyzr: List[np.ndarray] = []
114+
self._radius_generating_fns = None # Defined by `.read_swc()`.
112115

113116
# For debugging the solver. Will be empty by default and only filled if
114117
# `self._init_morph_for_debugging` is run.
@@ -381,6 +384,137 @@ def _data_set(
381384
raise KeyError("Key not recognized.")
382385
return param_state
383386

387+
def _set_ncomp(
388+
self,
389+
ncomp: int,
390+
view: pd.DataFrame,
391+
all_nodes: pd.DataFrame,
392+
start_idx: int,
393+
nseg_per_branch: jnp.asarray,
394+
channel_names: List[str],
395+
channel_param_names: List[str],
396+
channel_state_names: List[str],
397+
radius_generating_fns: List[Callable],
398+
min_radius: Optional[float],
399+
):
400+
"""Set the number of compartments with which the branch is discretized."""
401+
within_branch_radiuses = view["radius"].to_numpy()
402+
compartment_lengths = view["length"].to_numpy()
403+
num_previous_ncomp = len(within_branch_radiuses)
404+
branch_indices = pd.unique(view["branch_index"])
405+
406+
error_msg = lambda name: (
407+
f"You previously modified the {name} of individual compartments, but "
408+
f"now you are modifying the number of compartments in this branch. "
409+
f"This is not allowed. First build the morphology with `set_ncomp()` and "
410+
f"then modify the radiuses and lengths of compartments."
411+
)
412+
413+
if (
414+
~np.all(within_branch_radiuses == within_branch_radiuses[0])
415+
and radius_generating_fns is None
416+
):
417+
raise ValueError(error_msg("radius"))
418+
419+
for property_name in ["length", "capacitance", "axial_resistivity"]:
420+
compartment_properties = view[property_name].to_numpy()
421+
if ~np.all(compartment_properties == compartment_properties[0]):
422+
raise ValueError(error_msg(property_name))
423+
424+
if not (view[channel_names].var() == 0.0).all():
425+
raise ValueError(
426+
"Some channel exists only in some compartments of the branch which you"
427+
"are trying to modify. This is not allowed. First specify the number"
428+
"of compartments with `.set_ncomp()` and then insert the channels"
429+
"accordingly."
430+
)
431+
432+
if not (view[channel_param_names + channel_state_names].var() == 0.0).all():
433+
raise ValueError(
434+
"Some channel has different parameters or states between the "
435+
"different compartments of the branch which you are trying to modify. "
436+
"This is not allowed. First specify the number of compartments with "
437+
"`.set_ncomp()` and then insert the channels accordingly."
438+
)
439+
440+
# Add new rows as the average of all rows. Special case for the length is below.
441+
average_row = view.mean(skipna=False)
442+
average_row = average_row.to_frame().T
443+
view = pd.concat([*[average_row] * ncomp], axis="rows")
444+
445+
# If the `view` is not the entire `Module`, but a `View` (i.e. if one changes
446+
# the number of comps within a branch of a cell), then the `self.pointer.view`
447+
# will contain the additional `global_xyz_index` columns. However, the
448+
# `self.nodes` will not have these columns.
449+
#
450+
# Note that we assert that there are no trainables, so `controlled_by_params`
451+
# of the `self.nodes` has to be empty.
452+
if "global_comp_index" in view.columns:
453+
view = view.drop(
454+
columns=[
455+
"global_comp_index",
456+
"global_branch_index",
457+
"global_cell_index",
458+
"controlled_by_param",
459+
]
460+
)
461+
462+
# Set the correct datatype after having performed an average which cast
463+
# everything to float.
464+
integer_cols = ["comp_index", "branch_index", "cell_index"]
465+
view[integer_cols] = view[integer_cols].astype(int)
466+
467+
# Whether or not a channel exists in a compartment is a boolean.
468+
boolean_cols = channel_names
469+
view[boolean_cols] = view[boolean_cols].astype(bool)
470+
471+
# Special treatment for the lengths and radiuses. These are not being set as
472+
# the average because we:
473+
# 1) Want to maintain the total length of a branch.
474+
# 2) Want to use the SWC inferred radius.
475+
#
476+
# Compute new compartment lengths.
477+
comp_lengths = np.sum(compartment_lengths) / ncomp
478+
view["length"] = comp_lengths
479+
480+
# Compute new compartment radiuses.
481+
if radius_generating_fns is not None:
482+
view["radius"] = build_radiuses_from_xyzr(
483+
radius_fns=radius_generating_fns,
484+
branch_indices=branch_indices,
485+
min_radius=min_radius,
486+
nseg=ncomp,
487+
)
488+
else:
489+
view["radius"] = within_branch_radiuses[0] * np.ones(ncomp)
490+
491+
# Update `.nodes`.
492+
#
493+
# 1) Delete N rows starting from start_idx
494+
number_deleted = num_previous_ncomp
495+
all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))
496+
497+
# 2) Insert M new rows at the same location
498+
df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point
499+
df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point
500+
501+
# 3) Combine the parts: before, new rows, and after
502+
all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)
503+
504+
# Override `comp_index` to just be a consecutive list.
505+
all_nodes["comp_index"] = np.arange(len(all_nodes))
506+
all_nodes = all_nodes.reset_index(drop=True)
507+
508+
# Update compartment structure arguments.
509+
nseg_per_branch = nseg_per_branch.at[branch_indices].set(ncomp)
510+
nseg = int(jnp.max(nseg_per_branch))
511+
cumsum_nseg = jnp.concatenate(
512+
[jnp.asarray([0]), jnp.cumsum(nseg_per_branch)]
513+
).astype(int)
514+
internal_node_inds = np.arange(cumsum_nseg[-1])
515+
516+
return all_nodes, nseg_per_branch, nseg, cumsum_nseg, internal_node_inds
517+
384518
def make_trainable(
385519
self,
386520
key: str,

jaxley/modules/branch.py

+105-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

44
from copy import deepcopy
5+
from itertools import chain
56
from typing import Callable, Dict, List, Optional, Tuple, Union
67

78
import jax.numpy as jnp
@@ -25,6 +26,7 @@ class Branch(Module):
2526

2627
branch_params: Dict = {}
2728
branch_states: Dict = {}
29+
module_type = "branch"
2830

2931
def __init__(
3032
self,
@@ -56,7 +58,7 @@ def __init__(
5658
compartment_list = compartments
5759

5860
self.nseg = len(compartment_list)
59-
self.nseg_per_branch = [self.nseg]
61+
self.nseg_per_branch = jnp.asarray([self.nseg])
6062
self.total_nbranches = 1
6163
self.nbranches_per_cell = [1]
6264
self.cumsum_nbranches = jnp.asarray([0, 1])
@@ -146,6 +148,51 @@ def _init_morph_jax_spsolve(self):
146148
def __len__(self) -> int:
147149
return self.nseg
148150

151+
def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None):
152+
"""Set the number of compartments with which the branch is discretized.
153+
154+
Args:
155+
ncomp: The number of compartments that the branch should be discretized
156+
into.
157+
158+
Raises:
159+
- When the Module is a Network.
160+
- When there are stimuli in any compartment in the Module.
161+
- When there are recordings in any compartment in the Module.
162+
- When the channels of the compartments are not the same within the branch
163+
that is modified.
164+
- When the lengths of the compartments are not the same within the branch
165+
that is modified.
166+
- Unless the morphology was read from an SWC file, when the radiuses of the
167+
compartments are not the same within the branch that is modified.
168+
"""
169+
assert len(self.externals) == 0, "No stimuli allowed!"
170+
assert len(self.recordings) == 0, "No recordings allowed!"
171+
assert len(self.trainable_params) == 0, "No trainables allowed!"
172+
173+
# Update all attributes that are affected by compartment structure.
174+
(
175+
self.nodes,
176+
self.nseg_per_branch,
177+
self.nseg,
178+
self.cumsum_nseg,
179+
self._internal_node_inds,
180+
) = self._set_ncomp(
181+
ncomp,
182+
self.nodes,
183+
self.nodes,
184+
self.nodes["comp_index"].to_numpy()[0],
185+
self.nseg_per_branch,
186+
[c._name for c in self.channels],
187+
list(chain(*[c.channel_params for c in self.channels])),
188+
list(chain(*[c.channel_states for c in self.channels])),
189+
self._radius_generating_fns,
190+
min_radius,
191+
)
192+
193+
# Update the morphology indexing (e.g., `.comp_edges`).
194+
self.initialize()
195+
149196

150197
class BranchView(View):
151198
"""BranchView."""
@@ -167,3 +214,60 @@ def __getattr__(self, key):
167214
assert key in ["comp", "loc"]
168215
compview = CompartmentView(self.pointer, self.view)
169216
return compview if key == "comp" else compview.loc
217+
218+
def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None):
219+
"""Set the number of compartments with which the branch is discretized.
220+
221+
Args:
222+
ncomp: The number of compartments that the branch should be discretized
223+
into.
224+
min_radius: Only used if the morphology was read from an SWC file. If passed
225+
the radius is capped to be at least this value.
226+
227+
Raises:
228+
- When there are stimuli in any compartment in the module.
229+
- When there are recordings in any compartment in the module.
230+
- When the channels of the compartments are not the same within the branch
231+
that is modified.
232+
- When the lengths of the compartments are not the same within the branch
233+
that is modified.
234+
- Unless the morphology was read from an SWC file, when the radiuses of the
235+
compartments are not the same within the branch that is modified.
236+
"""
237+
if self.pointer.module_type == "network":
238+
raise NotImplementedError(
239+
"`.set_ncomp` is not yet supported for a `Network`. To overcome this, "
240+
"first build individual cells with the desired `ncomp` and then "
241+
"assemble them into a network."
242+
)
243+
244+
error_msg = lambda name: (
245+
f"Your module contains a {name}. This is not allowed. First build the "
246+
"morphology with `set_ncomp()` and then insert stimuli, recordings, and "
247+
"define trainables."
248+
)
249+
assert len(self.pointer.externals) == 0, error_msg("stimulus")
250+
assert len(self.pointer.recordings) == 0, error_msg("recording")
251+
assert len(self.pointer.trainable_params) == 0, error_msg("trainable parameter")
252+
# Update all attributes that are affected by compartment structure.
253+
(
254+
self.pointer.nodes,
255+
self.pointer.nseg_per_branch,
256+
self.pointer.nseg,
257+
self.pointer.cumsum_nseg,
258+
self.pointer._internal_node_inds,
259+
) = self.pointer._set_ncomp(
260+
ncomp,
261+
self.view,
262+
self.pointer.nodes,
263+
self.view["global_comp_index"].to_numpy()[0],
264+
self.pointer.nseg_per_branch,
265+
[c._name for c in self.pointer.channels],
266+
list(chain(*[c.channel_params for c in self.pointer.channels])),
267+
list(chain(*[c.channel_states for c in self.pointer.channels])),
268+
self.pointer._radius_generating_fns,
269+
min_radius,
270+
)
271+
272+
# Update the morphology indexing (e.g., `.comp_edges`).
273+
self.pointer.initialize()

0 commit comments

Comments
 (0)