Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add uninsert method #521

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,6 +1735,28 @@ def insert(self, channel: Channel):
for key in channel.channel_states:
self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]

def delete_channel(self, channel: Channel):
"""Remove a channel from the module.

Args:
channel: The channel to remove."""
name = channel._name
channel_names = [c._name for c in self.channels]
all_channel_names = [c._name for c in self.base.channels]
if name in channel_names:
channel_cols = list(channel.channel_params.keys())
channel_cols += list(channel.channel_states.keys())
self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan")
self.base.nodes.loc[self._nodes_in_view, name] = False

# only delete cols if no other comps in the module have the same channel
if np.all(~self.base.nodes[name]):
self.base.channels.pop(all_channel_names.index(name))
self.base.membrane_current_names.remove(channel.current_name)
self.base.nodes.drop(columns=channel_cols + [name], inplace=True)
else:
raise ValueError(f"Channel {name} not found in the module.")

@only_allow_module
def step(
self,
Expand Down
60 changes: 60 additions & 0 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,63 @@ def compute_current(self, states, v, params):
num_channels = 2
target = (t_max // dt + 2) * 0.001 * 0.01 * num_channels
assert np.abs(target - s[0, -1]) < 1e-8


def test_delete_channel(SimpleBranch):
# test complete removal of a channel from a module
branch1 = SimpleBranch(nseg=3)
branch1.comp(0).insert(K())
branch1.delete_channel(K())

branch2 = SimpleBranch(nseg=3)
branch2.comp(0).insert(K())
branch2.comp(0).delete_channel(K())

branch3 = SimpleBranch(nseg=3)
branch3.insert(K())
branch3.delete_channel(K())

def channel_present(view, channel, partial=False):
states_and_params = list(channel.channel_states.keys()) + list(
channel.channel_params.keys()
)
# none of the states or params should be in nodes
cols = view.nodes.columns.to_list()
channel_cols = [
col
for col in cols
if col.startswith(channel._name) and col != channel._name
]
diff = set(channel_cols).difference(set(states_and_params))
has_params_or_states = len(diff) > 0
has_channel_col = channel._name in view.nodes.columns
has_channel = channel._name in [c._name for c in view.channels]
has_mem_current = channel.current_name in view.membrane_current_names
if partial:
all_nans = (
not view.nodes[channel_cols].isna().all().all()
& ~view.nodes[channel._name].all()
)
return has_channel or has_mem_current or all_nans
return has_params_or_states or has_channel_col or has_channel or has_mem_current

for branch in [branch1, branch2, branch3]:
assert len(branch.channels) == 0
assert not channel_present(branch, K())

# test correct channels are removed only in the viewed part of the module
branch4 = SimpleBranch(nseg=3)
branch4.insert(HH())
branch4.comp(0).insert(K())
branch4.comp([1, 2]).insert(Leak())

branch4.comp(1).delete_channel(Leak())
# assert K in comp 0 and Leak still present in branch
assert channel_present(branch4.comp(0), K())
assert channel_present(branch4.comp(2), Leak(), partial=True)
assert not channel_present(branch4.comp(1), Leak(), partial=True)
assert channel_present(branch4, Leak())

branch4.comp(2).delete_channel(Leak())
# assert no more Leak
assert not channel_present(branch4, Leak())
Loading