|
5 | 5 | from abc import ABC, abstractmethod
|
6 | 6 | from copy import deepcopy
|
7 | 7 | from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 8 | +from warnings import warn |
8 | 9 |
|
9 | 10 | import jax.numpy as jnp
|
10 | 11 | import numpy as np
|
|
35 | 36 | from jaxley.utils.misc_utils import childview, concat_and_ignore_empty
|
36 | 37 | from jaxley.utils.plot_utils import plot_morph
|
37 | 38 | from jaxley.utils.solver_utils import convert_to_csc
|
| 39 | +from jaxley.utils.swc import build_radiuses_from_xyzr |
38 | 40 |
|
39 | 41 |
|
40 | 42 | class Module(ABC):
|
@@ -109,6 +111,7 @@ def __init__(self):
|
109 | 111 |
|
110 | 112 | # x, y, z coordinates and radius.
|
111 | 113 | self.xyzr: List[np.ndarray] = []
|
| 114 | + self._radius_generating_fns = None # Defined by `.read_swc()`. |
112 | 115 |
|
113 | 116 | # For debugging the solver. Will be empty by default and only filled if
|
114 | 117 | # `self._init_morph_for_debugging` is run.
|
@@ -381,6 +384,137 @@ def _data_set(
|
381 | 384 | raise KeyError("Key not recognized.")
|
382 | 385 | return param_state
|
383 | 386 |
|
| 387 | + def _set_ncomp( |
| 388 | + self, |
| 389 | + ncomp: int, |
| 390 | + view: pd.DataFrame, |
| 391 | + all_nodes: pd.DataFrame, |
| 392 | + start_idx: int, |
| 393 | + nseg_per_branch: jnp.asarray, |
| 394 | + channel_names: List[str], |
| 395 | + channel_param_names: List[str], |
| 396 | + channel_state_names: List[str], |
| 397 | + radius_generating_fns: List[Callable], |
| 398 | + min_radius: Optional[float], |
| 399 | + ): |
| 400 | + """Set the number of compartments with which the branch is discretized.""" |
| 401 | + within_branch_radiuses = view["radius"].to_numpy() |
| 402 | + compartment_lengths = view["length"].to_numpy() |
| 403 | + num_previous_ncomp = len(within_branch_radiuses) |
| 404 | + branch_indices = pd.unique(view["branch_index"]) |
| 405 | + |
| 406 | + error_msg = lambda name: ( |
| 407 | + f"You previously modified the {name} of individual compartments, but " |
| 408 | + f"now you are modifying the number of compartments in this branch. " |
| 409 | + f"This is not allowed. First build the morphology with `set_ncomp()` and " |
| 410 | + f"then modify the radiuses and lengths of compartments." |
| 411 | + ) |
| 412 | + |
| 413 | + if ( |
| 414 | + ~np.all(within_branch_radiuses == within_branch_radiuses[0]) |
| 415 | + and radius_generating_fns is None |
| 416 | + ): |
| 417 | + raise ValueError(error_msg("radius")) |
| 418 | + |
| 419 | + for property_name in ["length", "capacitance", "axial_resistivity"]: |
| 420 | + compartment_properties = view[property_name].to_numpy() |
| 421 | + if ~np.all(compartment_properties == compartment_properties[0]): |
| 422 | + raise ValueError(error_msg(property_name)) |
| 423 | + |
| 424 | + if not (view[channel_names].var() == 0.0).all(): |
| 425 | + raise ValueError( |
| 426 | + "Some channel exists only in some compartments of the branch which you" |
| 427 | + "are trying to modify. This is not allowed. First specify the number" |
| 428 | + "of compartments with `.set_ncomp()` and then insert the channels" |
| 429 | + "accordingly." |
| 430 | + ) |
| 431 | + |
| 432 | + if not (view[channel_param_names + channel_state_names].var() == 0.0).all(): |
| 433 | + raise ValueError( |
| 434 | + "Some channel has different parameters or states between the " |
| 435 | + "different compartments of the branch which you are trying to modify. " |
| 436 | + "This is not allowed. First specify the number of compartments with " |
| 437 | + "`.set_ncomp()` and then insert the channels accordingly." |
| 438 | + ) |
| 439 | + |
| 440 | + # Add new rows as the average of all rows. Special case for the length is below. |
| 441 | + average_row = view.mean(skipna=False) |
| 442 | + average_row = average_row.to_frame().T |
| 443 | + view = pd.concat([*[average_row] * ncomp], axis="rows") |
| 444 | + |
| 445 | + # If the `view` is not the entire `Module`, but a `View` (i.e. if one changes |
| 446 | + # the number of comps within a branch of a cell), then the `self.pointer.view` |
| 447 | + # will contain the additional `global_xyz_index` columns. However, the |
| 448 | + # `self.nodes` will not have these columns. |
| 449 | + # |
| 450 | + # Note that we assert that there are no trainables, so `controlled_by_params` |
| 451 | + # of the `self.nodes` has to be empty. |
| 452 | + if "global_comp_index" in view.columns: |
| 453 | + view = view.drop( |
| 454 | + columns=[ |
| 455 | + "global_comp_index", |
| 456 | + "global_branch_index", |
| 457 | + "global_cell_index", |
| 458 | + "controlled_by_param", |
| 459 | + ] |
| 460 | + ) |
| 461 | + |
| 462 | + # Set the correct datatype after having performed an average which cast |
| 463 | + # everything to float. |
| 464 | + integer_cols = ["comp_index", "branch_index", "cell_index"] |
| 465 | + view[integer_cols] = view[integer_cols].astype(int) |
| 466 | + |
| 467 | + # Whether or not a channel exists in a compartment is a boolean. |
| 468 | + boolean_cols = channel_names |
| 469 | + view[boolean_cols] = view[boolean_cols].astype(bool) |
| 470 | + |
| 471 | + # Special treatment for the lengths and radiuses. These are not being set as |
| 472 | + # the average because we: |
| 473 | + # 1) Want to maintain the total length of a branch. |
| 474 | + # 2) Want to use the SWC inferred radius. |
| 475 | + # |
| 476 | + # Compute new compartment lengths. |
| 477 | + comp_lengths = np.sum(compartment_lengths) / ncomp |
| 478 | + view["length"] = comp_lengths |
| 479 | + |
| 480 | + # Compute new compartment radiuses. |
| 481 | + if radius_generating_fns is not None: |
| 482 | + view["radius"] = build_radiuses_from_xyzr( |
| 483 | + radius_fns=radius_generating_fns, |
| 484 | + branch_indices=branch_indices, |
| 485 | + min_radius=min_radius, |
| 486 | + nseg=ncomp, |
| 487 | + ) |
| 488 | + else: |
| 489 | + view["radius"] = within_branch_radiuses[0] * np.ones(ncomp) |
| 490 | + |
| 491 | + # Update `.nodes`. |
| 492 | + # |
| 493 | + # 1) Delete N rows starting from start_idx |
| 494 | + number_deleted = num_previous_ncomp |
| 495 | + all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted)) |
| 496 | + |
| 497 | + # 2) Insert M new rows at the same location |
| 498 | + df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point |
| 499 | + df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point |
| 500 | + |
| 501 | + # 3) Combine the parts: before, new rows, and after |
| 502 | + all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True) |
| 503 | + |
| 504 | + # Override `comp_index` to just be a consecutive list. |
| 505 | + all_nodes["comp_index"] = np.arange(len(all_nodes)) |
| 506 | + all_nodes = all_nodes.reset_index(drop=True) |
| 507 | + |
| 508 | + # Update compartment structure arguments. |
| 509 | + nseg_per_branch = nseg_per_branch.at[branch_indices].set(ncomp) |
| 510 | + nseg = int(jnp.max(nseg_per_branch)) |
| 511 | + cumsum_nseg = jnp.concatenate( |
| 512 | + [jnp.asarray([0]), jnp.cumsum(nseg_per_branch)] |
| 513 | + ).astype(int) |
| 514 | + internal_node_inds = np.arange(cumsum_nseg[-1]) |
| 515 | + |
| 516 | + return all_nodes, nseg_per_branch, nseg, cumsum_nseg, internal_node_inds |
| 517 | + |
384 | 518 | def make_trainable(
|
385 | 519 | self,
|
386 | 520 | key: str,
|
|
0 commit comments