Skip to content

Commit

Permalink
Feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Sep 23, 2024
1 parent 57d68a9 commit e1338a5
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _init_morph_jaxley_spsolve(self):
)
self.root_inds = jnp.asarray([0])

# Generate mapping to dealing with the masking which allows using the custom
# Generate mapping to deal with the masking which allows using the custom
# sparse solver to deal with different nseg per branch.
self._remapped_node_indices = remap_index_to_masked(
self._internal_node_inds,
Expand Down
4 changes: 2 additions & 2 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _init_morph_jax_spsolve(self):
self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, self.cells
):
offset_within_cell = cell.cumsum_nseg[-1]
condition = np.isin(cell._comp_edges["type"].to_numpy(), [1, 2])
condition = cell._comp_edges["type"].isin([1, 2])
rows = cell._comp_edges[condition]
self._comp_edges = pd.concat(
[
Expand All @@ -222,7 +222,7 @@ def _init_morph_jax_spsolve(self):
self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, self.cells
):
offset_within_cell = cell.cumsum_nseg[-1]
condition = np.isin(cell._comp_edges["type"].to_numpy(), [3, 4])
condition = cell._comp_edges["type"].isin([3, 4])
rows = cell._comp_edges[condition]
self._comp_edges = pd.concat(
[
Expand Down
4 changes: 2 additions & 2 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def compute_axial_conductances(
conds_c2c = jnp.asarray([])

# `branchpoint-to-compartment` (bp2c) axial coupling conductances.
condition = np.isin(comp_edges["type"].to_numpy(), [1, 2])
condition = comp_edges["type"].isin([1, 2])
sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())

if len(sink_comp_inds) > 0:
Expand All @@ -454,7 +454,7 @@ def compute_axial_conductances(
conds_bp2c = jnp.asarray([])

# `compartment-to-branchpoint` (c2bp) axial coupling conductances.
condition = np.isin(comp_edges["type"].to_numpy(), [3, 4])
condition = comp_edges["type"].isin([3, 4])
source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())

if len(source_comp_inds) > 0:
Expand Down
5 changes: 4 additions & 1 deletion jaxley/utils/solver_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def remap_index_to_masked(
"""Convert actual index of the compartment to the index in the masked system.
E.g. if `nsegs = [2, 4]`, then the index `3` would be mapped to `5` because the
masked `nsegs` are `[4, 4]`.
masked `nsegs` are `[4, 4]`. I.e.:
original: [0, 1, 2, 3, 4, 5]
masked: [0, 1, (2) ,(3) ,4, 5, 6, 7]
"""
cumsum_nseg_per_branch = jnp.concatenate(
[
Expand Down

0 comments on commit e1338a5

Please sign in to comment.