diff --git a/spikeinterface_gui/__init__.py b/spikeinterface_gui/__init__.py index 90764da..f8923a7 100644 --- a/spikeinterface_gui/__init__.py +++ b/spikeinterface_gui/__init__.py @@ -12,5 +12,5 @@ from .version import version as __version__ -from .main import run_mainwindow, run_launcher +from .main import run_mainwindow, run_launcher, run_compare_analyzer diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index c5b1839..ea315bc 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -3,6 +3,7 @@ import markdown import numpy as np from copy import copy +import itertools import weakref @@ -420,3 +421,157 @@ def refresh(self): 'horizontal' : QT.Qt.Horizontal, 'vertical' : QT.Qt.Vertical, } + + +class ControllerSynchronizer(QT.QWidget): + def __init__(self, sorting_comparison, controllers, windows, names, parent=None): + QT.QWidget.__init__(self, parent=parent) + + self.comp = sorting_comparison + self.controllers = controllers + self.windows = windows + self.names = names + + self.layout = QT.QVBoxLayout() + self.setLayout(self.layout) + + self.label = QT.QLabel('') + self.layout.addWidget(self.label) + + + for i, window in enumerate(windows): + + # this is not working ???!!!!! + # callback = lambda: self.on_unit_visibility_changed(win_ind=i) + + # so uggly solution + callback = [self.on_unit_visibility_changed_0, self.on_unit_visibility_changed_1][i] + + for view in window.views.values(): + view.notifier.unit_visibility_changed.connect(callback) + + settings = [ + {'name': 'mode', 'type': 'list', 'limits' : ['best', 'all',] }, + {'name': 'thresh', 'type': 'float', 'value' : 0.3, 'step': 0.01, 'limits': (0, 1.)}, + ] + self.settings = pg.parametertree.Parameter.create(name="settings", type='group', children=settings) + + # not that the parent is not the view (not Qt anymore) itself but the widget + self.tree_settings = pg.parametertree.ParameterTree(parent=self) + self.tree_settings.header().hide() + self.tree_settings.setParameters(self.settings, showTop=True) + self.tree_settings.setWindowTitle('Settings') + self.layout.addWidget(self.tree_settings) + + from .utils_qt import ViewBoxHandlingClickToPositionWithCtrl + + self.graphicsview = pg.GraphicsView() + self.layout.addWidget(self.graphicsview) + self.viewBox = ViewBoxHandlingClickToPositionWithCtrl() + self.viewBox.clicked.connect(self._qt_select_pair) + self.viewBox.disableAutoRange() + self.plot = pg.PlotItem(viewBox=self.viewBox) + self.graphicsview.setCentralItem(self.plot) + self.plot.hideButtons() + self.image = pg.ImageItem() + self.plot.addItem(self.image) + self.plot.hideAxis('bottom') + self.plot.hideAxis('left') + + + import matplotlib + N = 512 + cmap_name = 'viridis' + cmap = matplotlib.colormaps[cmap_name].resampled(N) + lut = [] + for i in range(N): + r,g,b,_ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) + lut.append([r*255,g*255,b*255]) + self.lut = np.array(lut, dtype='uint8') + + # agreement = self.comp.agreement_scores.values + self.agreement_ordered = self.comp.get_ordered_agreement_scores() + self.image.setImage(self.agreement_ordered.values , lut=self.lut, levels=[0, 1]) + self.image.show() + self.plot.setXRange(0, self.agreement_ordered.shape[0]) + self.plot.setLabel('bottom', names[0]) + self.plot.setYRange(0, self.agreement_ordered.shape[1]) + self.plot.setLabel('left', names[1]) + + + + def on_unit_visibility_changed_0(self): + self.on_unit_visibility_changed(0) + + def on_unit_visibility_changed_1(self): + self.on_unit_visibility_changed(1) + + + def on_unit_visibility_changed(self, win_ind): + changed_controller = self.controllers[win_ind] + visible_unit_inds = changed_controller.get_visible_unit_indices() + visible_unit_ids = changed_controller.get_visible_unit_ids() + if len(visible_unit_inds) != 1: + # TODO handle several units at once + return + + unit_ind = visible_unit_inds[0] + + agreement = self.comp.agreement_scores.values + if win_ind == 1: + agreement = agreement.T + + thresh = self.settings['thresh'] + mode = self.settings['mode'] + + other_ind = (win_ind + 1) % 2 + other_controller = self.controllers[other_ind] + other_window = self.windows[other_ind] + + if mode == 'all': + other_visible_inds = agreement[unit_ind, :] > thresh + elif mode == 'best': + best_ind = np.argmax(agreement[unit_ind, :]) + if agreement[unit_ind, best_ind] > thresh: + other_visible_inds = [best_ind] + else: + other_visible_inds = [] + + other_visible_ids = other_controller.unit_ids[other_visible_inds] + other_controller.set_visible_unit_ids(other_visible_ids) + + for view in other_window.views.values(): + view.refresh() + + + self._refresh_label() + + def _refresh_label(self): + + txt = '' + unit_ids0 = self.controllers[0].get_visible_unit_ids() + unit_ids1 = self.controllers[1].get_visible_unit_ids() + for unit_id0, unit_id1 in itertools.product(unit_ids0, unit_ids1): + a = self.comp.agreement_scores.loc[unit_id0, unit_id1] + txt += f'{self.names[0]} unit {unit_id0} - {self.names[1]} unit {unit_id1} agreement={a}\n' + self.label.setText(txt) + + def _qt_select_pair(self, x, y, reset): + c0 = self.controllers[0] + c1 = self.controllers[1] + + # used + ordered_unit_ids1 = self.agreement_ordered.index + ordered_unit_ids2 = self.agreement_ordered.columns + unit_id0 = ordered_unit_ids1[int(np.floor(x))] + unit_id1 = ordered_unit_ids2[int(np.floor(y))] + + c0.set_visible_unit_ids([unit_id0]) + c1.set_visible_unit_ids([unit_id1]) + + for win in self.windows: + for view in win.views.values(): + view.refresh() + + self._refresh_label() + diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 19fe9b4..67878be 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -377,4 +377,82 @@ def find_skippable_extensions(layout_dict): skippable_extensions = list(all_extensions.difference(set(needed_extensions))) - return skippable_extensions \ No newline at end of file + return skippable_extensions + + +def run_compare_analyzer( + analyzers, + mode="desktop", + with_traces=False, + # displayed_unit_properties=None, + skip_extensions=None, + layout_preset=None, + layout=None, + verbose=False, + names=None, + # user_settings=None, + # disable_save_settings_button=False, +): + + assert isinstance(analyzers, list) + assert len(analyzers) == 2 + assert mode == "desktop" + + from spikeinterface_gui.myqt import QT, mkQApp + from spikeinterface_gui.backend_qt import QtMainWindow, ControllerSynchronizer + from spikeinterface.comparison import compare_two_sorters + + app = mkQApp() + + + layout_dict = get_layout_description(layout_preset, layout) + + if names is None: + names = [f'Analyzer {i}' for i in range(2)] + + controllers = [] + windows = [] + for i, analyzer in enumerate(analyzers): + + if verbose: + import time + t0 = time.perf_counter() + + controller = Controller( + analyzer, backend="qt", + # verbose=verbose, + verbose=False, + + curation=False, + with_traces=with_traces, + skip_extensions=skip_extensions, + ) + + if verbose: + t1 = time.perf_counter() + print('controller init time', t1 - t0) + + + # Suppress a known pyqtgraph warning + warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph") + warnings.filterwarnings('ignore', category=UserWarning, message=".*QObject::connect.*") + + + + win = QtMainWindow(controller, layout_dict=layout_dict) #, user_settings=user_settings) + name = names[i] + win.setWindowTitle(name) + # Set window icon + icon_file = Path(__file__).absolute().parent / 'img' / 'si.png' + if icon_file.exists(): + app.setWindowIcon(QT.QIcon(str(icon_file))) + win.show() + windows.append(win) + controllers.append(controller) + + comp = compare_two_sorters(analyzers[0].sorting, analyzers[1].sorting) + + synchronizer = ControllerSynchronizer(comp, controllers, windows, names) + synchronizer.show() + + app.exec() \ No newline at end of file diff --git a/spikeinterface_gui/tests/test_compare_qt.py b/spikeinterface_gui/tests/test_compare_qt.py new file mode 100644 index 0000000..d3783e9 --- /dev/null +++ b/spikeinterface_gui/tests/test_compare_qt.py @@ -0,0 +1,43 @@ +from argparse import ArgumentParser +from spikeinterface_gui import run_compare_analyzer + +from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder, make_curation_dict + +from spikeinterface import load_sorting_analyzer + + +from pathlib import Path + +import numpy as np +import sys + + + + +def setup_module(): + global test_folder + case = test_folder.stem.split('_')[-1] + make_analyzer_folder(test_folder, case=case) + +def teardown_module(): + clean_all(test_folder) + + +def test_run_compare_analyzer(): + analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") + analyzers = [analyzer, analyzer] + run_compare_analyzer( + analyzers, + mode="desktop", + verbose=True, + ) + +if __name__ == '__main__': + global test_folder + + dataset = "small" + test_folder = Path(dataset).parent / f"my_dataset_{dataset}" + if not test_folder.is_dir(): + setup_module() + + win = test_run_compare_analyzer()