Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion spikeinterface_gui/backend_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SignalNotifier(QT.QObject):
channel_visibility_changed = QT.pyqtSignal()
manual_curation_updated = QT.pyqtSignal()
time_info_updated = QT.pyqtSignal()
use_times_updated = QT.pyqtSignal()
unit_color_changed = QT.pyqtSignal()

def __init__(self, parent=None, view=None):
Expand All @@ -40,6 +41,9 @@ def notify_manual_curation_updated(self):
def notify_time_info_updated(self):
self.time_info_updated.emit()

def notify_use_times_updated(self):
self.use_times_updated.emit()

def notify_unit_color_changed(self):
self.unit_color_changed.emit()

Expand All @@ -63,6 +67,7 @@ def connect_view(self, view):
view.notifier.channel_visibility_changed.connect(self.on_channel_visibility_changed)
view.notifier.manual_curation_updated.connect(self.on_manual_curation_updated)
view.notifier.time_info_updated.connect(self.on_time_info_updated)
view.notifier.use_times_updated.connect(self.on_use_times_updated)
view.notifier.unit_color_changed.connect(self.on_unit_color_changed)

def on_spike_selection_changed(self):
Expand Down Expand Up @@ -110,7 +115,16 @@ def on_time_info_updated(self):
# do not refresh it self
continue
view.on_time_info_updated()


def on_use_times_updated(self):
if not self._active:
return
for view in self.controller.views:
if view.qt_widget == self.sender().parent():
# do not refresh it self
continue
view.on_use_times_updated()

def on_unit_color_changed(self):
if not self._active:
return
Expand Down
73 changes: 48 additions & 25 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)


def get_unit_data(self, unit_id, seg_index=0):
inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index)
spike_times = self.controller.spikes["sample_index"][inds] / self.controller.sampling_frequency
def get_unit_data(self, unit_id, segment_index=0):
inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index)
spike_indices = self.controller.spikes["sample_index"][inds]
spike_times = self.controller.sample_index_to_time(spike_indices)
spike_data = self.spike_data[inds]
ptp = np.ptp(spike_data)
hist_min, hist_max = [np.min(spike_data) - 0.2 * ptp, np.max(spike_data) + 0.2 * ptp]
Expand All @@ -53,8 +54,8 @@ def get_unit_data(self, unit_id, seg_index=0):

return spike_times, spike_data, hist_count, hist_bins, inds

def get_selected_spikes_data(self, seg_index=0, visible_inds=None):
sl = self.controller.segment_slices[seg_index]
def get_selected_spikes_data(self, segment_index=0, visible_inds=None):
sl = self.controller.segment_slices[segment_index]
spikes_in_seg = self.controller.spikes[sl]
selected_indices = self.controller.get_indices_spike_selected()
if visible_inds is not None:
Expand Down Expand Up @@ -85,7 +86,7 @@ def select_all_spikes_from_lasso(self, keep_already_selected=False):
for segment_index, vertices in self._lasso_vertices.items():
if vertices is None:
continue
spike_inds = self.controller.get_spike_indices(visible_unit_id, seg_index=segment_index)
spike_inds = self.controller.get_spike_indices(visible_unit_id, segment_index=segment_index)
spike_times = self.controller.spikes["sample_index"][spike_inds] / fs
spike_data = self.spike_data[spike_inds]

Expand Down Expand Up @@ -119,7 +120,7 @@ def split(self):

if self.controller.num_segments > 1:
# check that lasso vertices are defined for all segments
if not all(self._lasso_vertices[seg_index] is not None for seg_index in range(self.controller.num_segments)):
if not all(self._lasso_vertices[segment_index] is not None for segment_index in range(self.controller.num_segments)):
# Use the new continue_from_user pattern
self.continue_from_user(
"Not all segments have lasso selection. "
Expand Down Expand Up @@ -163,6 +164,12 @@ def on_unit_visibility_changed(self):
self._current_selected = self.controller.get_indices_spike_selected().size
self.refresh()

def on_time_info_updated(self):
return self.refresh()

def on_use_times_updated(self):
return self.refresh()

## QT zone ##
def _qt_make_layout(self):
from .myqt import QT
Expand All @@ -174,8 +181,8 @@ def _qt_make_layout(self):
tb = self.qt_widget.view_toolbar
self.combo_seg = QT.QComboBox()
tb.addWidget(self.combo_seg)
self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ])
self.combo_seg.currentIndexChanged.connect(self.refresh)
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
add_stretch_to_qtoolbar(tb)
self.lasso_but = QT.QPushButton("select", checkable = True)
tb.addWidget(self.lasso_but)
Expand Down Expand Up @@ -235,6 +242,12 @@ def _qt_initialize_plot(self):
def _qt_on_spike_selection_changed(self):
self.refresh()

def _qt_change_segment(self):
segment_index = self.combo_seg.currentIndex()
self.controller.set_time(segment_index=segment_index)
self.refresh()
self.notify_time_info_updated()

def _qt_refresh(self):
from .myqt import QT
import pyqtgraph as pg
Expand All @@ -246,13 +259,18 @@ def _qt_refresh(self):
if self.spike_data is None:
return

segment_index = self.controller.get_time()[1]
# Update combo_seg if it doesn't match the current segment index
if self.combo_seg.currentIndex() != segment_index:
self.combo_seg.setCurrentIndex(segment_index)

max_count = 1
all_inds = []
for unit_id in self.controller.get_visible_unit_ids():

spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
unit_id,
seg_index=self.combo_seg.currentIndex()
segment_index=segment_index
)

# make a copy of the color
Expand All @@ -276,7 +294,7 @@ def _qt_refresh(self):
y_range_plot_1 = self.plot.getViewBox().viewRange()
self.viewBox2.setYRange(y_range_plot_1[1][0], y_range_plot_1[1][1], padding = 0.0)

spike_times, spike_data = self.get_selected_spikes_data(seg_index=self.combo_seg.currentIndex(), visible_inds=all_inds)
spike_times, spike_data = self.get_selected_spikes_data(segment_index=self.combo_seg.currentIndex(), visible_inds=all_inds)

self.scatter_select.setData(spike_times, spike_data)

Expand All @@ -296,8 +314,8 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
self.lasso.setData([], [])
vertices = np.array(points)

seg_index = self.combo_seg.currentIndex()
sl = self.controller.segment_slices[seg_index]
segment_index = self.combo_seg.currentIndex()
sl = self.controller.segment_slices[segment_index]
spikes_in_seg = self.controller.spikes[sl]

# Create mask for visible units
Expand All @@ -315,16 +333,16 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
self.notify_spike_selection_changed()
return

if self._lasso_vertices[seg_index] is None:
self._lasso_vertices[seg_index] = []
if self._lasso_vertices[segment_index] is None:
self._lasso_vertices[segment_index] = []

if shift_held:
# If shift is held, append the vertices to the current lasso vertices
self._lasso_vertices[seg_index].append(vertices)
self._lasso_vertices[segment_index].append(vertices)
keep_already_selected = True
else:
# If shift is not held, clear the existing lasso vertices for this segment
self._lasso_vertices[seg_index] = [vertices]
self._lasso_vertices[segment_index] = [vertices]
keep_already_selected = False

self.select_all_spikes_from_lasso(keep_already_selected=keep_already_selected)
Expand Down Expand Up @@ -445,11 +463,13 @@ def _panel_refresh(self):
ys = []
colors = []

segment_index = self.controller.get_time()[1]

visible_unit_ids = self.controller.get_visible_unit_ids()
for unit_id in visible_unit_ids:
spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
unit_id,
seg_index=self.segment_index
segment_index=segment_index
)
color = self.get_unit_color(unit_id)
xs.extend(spike_times)
Expand Down Expand Up @@ -504,9 +524,12 @@ def _panel_on_select_button(self, event):
def _panel_change_segment(self, event):
self._current_selected = 0
self.segment_index = int(self.segment_selector.value.split()[-1])
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
self.scatter_fig.x_range.end = time_max
self.controller.set_time(segment_index=self.segment_index)
t_start, t_end = self.controller.get_t_start_t_end()
self.scatter_fig.x_range.start = t_start
self.scatter_fig.x_range.end = t_end
self.refresh()
self.notify_time_info_updated()

def _on_panel_selection_geometry(self, event):
"""
Expand All @@ -524,16 +547,16 @@ def _on_panel_selection_geometry(self, event):
return

# Append the current polygon to the lasso vertices if shift is held
seg_index = self.segment_index
if self._lasso_vertices[seg_index] is None:
self._lasso_vertices[seg_index] = []
segment_index = self.segment_index
if self._lasso_vertices[segment_index] is None:
self._lasso_vertices[segment_index] = []
if len(selected) > self._current_selected:
self._current_selected = len(selected)
# Store the current polygon for the current segment
self._lasso_vertices[seg_index].append(polygon)
self._lasso_vertices[segment_index].append(polygon)
keep_already_selected = True
else:
self._lasso_vertices[seg_index] = [polygon]
self._lasso_vertices[segment_index] = [polygon]
keep_already_selected = False

self.select_all_spikes_from_lasso(keep_already_selected)
Expand Down
70 changes: 55 additions & 15 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_default_main_settings = dict(
max_visible_units=10,
color_mode='color_by_unit',
use_times=False
)

from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties
Expand Down Expand Up @@ -264,7 +265,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save

# self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict")
seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1))
self.segment_slices = {seg_index: slice(seg_limits[seg_index], seg_limits[seg_index + 1]) for seg_index in range(num_seg)}
self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)}

spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False)
self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2]
Expand All @@ -275,7 +276,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
spike_per_seg = [s.size for s in spike_vector2]
# dict[unit_id] -> all indices for this unit across segments
self._spike_index_by_units = {}
# dict[seg_index][unit_id] -> all indices for this unit for one segment
# dict[segment_index][unit_id] -> all indices for this unit for one segment
self._spike_index_by_segment_and_units = spike_indices_abs
for unit_id in unit_ids:
inds = []
Expand All @@ -302,10 +303,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
self.displayed_unit_properties = displayed_unit_properties

# set default time info
self.time_info = dict(
time_by_seg=np.array([0] * self.num_segments, dtype="float64"),
segment_index=0
)
self.update_time_info()

self.curation = curation
# TODO: Reload the dictionary if it already exists
Expand Down Expand Up @@ -401,10 +399,10 @@ def get_time(self):
"""
Returns selected time and segment index
"""
seg_index = self.time_info['segment_index']
segment_index = self.time_info['segment_index']
time_by_seg = self.time_info['time_by_seg']
time = time_by_seg[seg_index]
return time, seg_index
time = time_by_seg[segment_index]
return time, segment_index

def set_time(self, time=None, segment_index=None):
"""
Expand All @@ -418,7 +416,49 @@ def set_time(self, time=None, segment_index=None):
segment_index = self.time_info['segment_index']
if time is not None:
self.time_info['time_by_seg'][segment_index] = time


def update_time_info(self):
# set default time info
if self.main_settings["use_times"] and self.analyzer.has_recording():
self.time_info = dict(
time_by_seg=np.array(
[
self.analyzer.recording.get_start_time(segment_index) for segment_index in range(self.num_segments)
],
dtype="float64"),
segment_index=0
)
else:
self.time_info = dict(
time_by_seg=np.array([0] * self.num_segments, dtype="float64"),
segment_index=0
)

def get_t_start_t_stop(self):
segment_index = self.time_info["segment_index"]
if self.main_settings["use_times"] and self.analyzer.has_recording():
t_start = self.analyzer.recording.get_start_time(segment_index=segment_index)
t_stop = self.analyzer.recording.get_end_time(segment_index=segment_index)
return t_start, t_stop
else:
return 0, self.get_num_samples(segment_index) / self.sampling_frequency

def sample_index_to_time(self, sample_index):
segment_index = self.time_info["segment_index"]
if self.main_settings["use_times"] and self.analyzer.has_recording():
time = self.analyzer.recording.sample_index_to_time(sample_index, segment_index=segment_index)
return time
else:
return sample_index / self.sampling_frequency

def time_to_sample_index(self, time):
segment_index = self.time_info["segment_index"]
if self.main_settings["use_times"] and self.analyzer.has_recording():
time = self.analyzer.recording.time_to_sample_index(time, segment_index=segment_index)
return time
else:
return int(time * self.sampling_frequency)

def get_information_txt(self):
nseg = self.analyzer.get_num_segments()
nchan = self.analyzer.get_num_channels()
Expand Down Expand Up @@ -552,13 +592,13 @@ def set_indices_spike_selected(self, inds):
sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]]
self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index)

def get_spike_indices(self, unit_id, seg_index=None):
if seg_index is None:
def get_spike_indices(self, unit_id, segment_index=None):
if segment_index is None:
# dict[unit_id] -> all indices for this unit across segments
return self._spike_index_by_units[unit_id]
else:
# dict[seg_index][unit_id] -> all indices for this unit for one segment
return self._spike_index_by_segment_and_units[seg_index][unit_id]
# dict[segment_index][unit_id] -> all indices for this unit for one segment
return self._spike_index_by_segment_and_units[segment_index][unit_id]

def get_num_samples(self, segment_index):
return self.analyzer.get_num_samples(segment_index=segment_index)
Expand Down Expand Up @@ -838,7 +878,7 @@ def make_manual_split_if_possible(self, unit_id):
indices = self.get_indices_spike_selected()
if len(indices) == 0:
return False
spike_inds = self.get_spike_indices(unit_id, seg_index=None)
spike_inds = self.get_spike_indices(unit_id, segment_index=None)
if not np.all(np.isin(indices, spike_inds)):
return False

Expand Down
2 changes: 1 addition & 1 deletion spikeinterface_gui/curationview.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def unsplit(self):
def select_and_notify_split(self, split_unit_id):
self.controller.set_visible_unit_ids([split_unit_id])
self.notify_unit_visibility_changed()
spike_inds = self.controller.get_spike_indices(split_unit_id, seg_index=None)
spike_inds = self.controller.get_spike_indices(split_unit_id, segment_index=None)
active_split = [s for s in self.controller.curation_data['splits'] if s['unit_id'] == split_unit_id][0]
split_indices = active_split['indices'][0]
self.controller.set_indices_spike_selected(spike_inds[split_indices])
Expand Down
Loading