diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index d8aa513b..169fca46 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -179,7 +179,7 @@ def _init_morph_jaxley_spsolve(self): ) self.root_inds = jnp.asarray([0]) - # Generate mapping to dealing with the masking which allows using the custom + # Generate mapping to deal with the masking which allows using the custom # sparse solver to deal with different nseg per branch. self._remapped_node_indices = remap_index_to_masked( self._internal_node_inds, diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 707777ef..7e356016 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -202,7 +202,7 @@ def _init_morph_jax_spsolve(self): self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, self.cells ): offset_within_cell = cell.cumsum_nseg[-1] - condition = np.isin(cell._comp_edges["type"].to_numpy(), [1, 2]) + condition = cell._comp_edges["type"].isin([1, 2]) rows = cell._comp_edges[condition] self._comp_edges = pd.concat( [ @@ -222,7 +222,7 @@ def _init_morph_jax_spsolve(self): self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, self.cells ): offset_within_cell = cell.cumsum_nseg[-1] - condition = np.isin(cell._comp_edges["type"].to_numpy(), [3, 4]) + condition = cell._comp_edges["type"].isin([3, 4]) rows = cell._comp_edges[condition] self._comp_edges = pd.concat( [ diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index ceac97d8..ffa8c0e3 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -438,7 +438,7 @@ def compute_axial_conductances( conds_c2c = jnp.asarray([]) # `branchpoint-to-compartment` (bp2c) axial coupling conductances. - condition = np.isin(comp_edges["type"].to_numpy(), [1, 2]) + condition = comp_edges["type"].isin([1, 2]) sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list()) if len(sink_comp_inds) > 0: @@ -454,7 +454,7 @@ def compute_axial_conductances( conds_bp2c = jnp.asarray([]) # `compartment-to-branchpoint` (c2bp) axial coupling conductances. - condition = np.isin(comp_edges["type"].to_numpy(), [3, 4]) + condition = comp_edges["type"].isin([3, 4]) source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list()) if len(source_comp_inds) > 0: diff --git a/jaxley/utils/solver_utils.py b/jaxley/utils/solver_utils.py index 36759e42..50d161cb 100644 --- a/jaxley/utils/solver_utils.py +++ b/jaxley/utils/solver_utils.py @@ -14,7 +14,10 @@ def remap_index_to_masked( """Convert actual index of the compartment to the index in the masked system. E.g. if `nsegs = [2, 4]`, then the index `3` would be mapped to `5` because the - masked `nsegs` are `[4, 4]`. + masked `nsegs` are `[4, 4]`. I.e.: + + original: [0, 1, 2, 3, 4, 5] + masked: [0, 1, (2) ,(3) ,4, 5, 6, 7] """ cumsum_nseg_per_branch = jnp.concatenate( [