From 0c650ed8c6a63a680b62e49af4c873f2eb02feb0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Nov 2025 17:30:00 +0100 Subject: [PATCH 1/6] Add run_compare_analyzer() --- spikeinterface_gui/__init__.py | 2 +- spikeinterface_gui/backend_qt.py | 90 ++++++++++++++++++++++++++++++++ spikeinterface_gui/main.py | 73 +++++++++++++++++++++++++- 3 files changed, 163 insertions(+), 2 deletions(-) diff --git a/spikeinterface_gui/__init__.py b/spikeinterface_gui/__init__.py index 90764da..f8923a7 100644 --- a/spikeinterface_gui/__init__.py +++ b/spikeinterface_gui/__init__.py @@ -12,5 +12,5 @@ from .version import version as __version__ -from .main import run_mainwindow, run_launcher +from .main import run_mainwindow, run_launcher, run_compare_analyzer diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index c5b1839..830111a 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -420,3 +420,93 @@ def refresh(self): 'horizontal' : QT.Qt.Horizontal, 'vertical' : QT.Qt.Vertical, } + + +class ControllerSynchronizer(QT.QWidget): + def __init__(self, sorting_comparison, controllers, windows, parent=None): + QT.QWidget.__init__(self, parent=parent) + + self.comp = sorting_comparison + self.controllers = controllers + self.windows = windows + + self.layout = QT.QVBoxLayout() + self.setLayout(self.layout) + + self.label = QT.QLabel('') + self.layout.addWidget(self.label) + + + for i, window in enumerate(windows): + + # this is not working ???!!!!! + # callback = lambda: self.on_unit_visibility_changed(win_ind=i) + + # so uggly solution + callback = [self.on_unit_visibility_changed_0, self.on_unit_visibility_changed_1][i] + + for view in window.views.values(): + view.notifier.unit_visibility_changed.connect(callback) + + settings = [ + {'name': 'mode', 'type': 'list', 'limits' : ['all', 'best', ] }, + {'name': 'thresh', 'type': 'float', 'value' : 0.05, 'step': 0.01, 'limits': (0, 1.)}, + ] + self.settings = pg.parametertree.Parameter.create(name="settings", type='group', children=settings) + + # not that the parent is not the view (not Qt anymore) itself but the widget + self.tree_settings = pg.parametertree.ParameterTree(parent=self) + self.tree_settings.header().hide() + self.tree_settings.setParameters(self.settings, showTop=True) + self.tree_settings.setWindowTitle('Settings') + self.layout.addWidget(self.tree_settings) + + + def on_unit_visibility_changed_0(self): + self.on_unit_visibility_changed(0) + + def on_unit_visibility_changed_1(self): + self.on_unit_visibility_changed(1) + + + def on_unit_visibility_changed(self, win_ind): + changed_controller = self.controllers[win_ind] + visible_unit_inds = changed_controller.get_visible_unit_indices() + visible_unit_ids = changed_controller.get_visible_unit_ids() + if len(visible_unit_inds) != 1: + # TODO handle several units at once + return + + unit_ind = visible_unit_inds[0] + + agreement = self.comp.agreement_scores.values + if win_ind == 1: + agreement = agreement.T + + thresh = self.settings['thresh'] + mode = self.settings['mode'] + + other_ind = (win_ind + 1) % 2 + other_controller = self.controllers[other_ind] + other_window = self.windows[other_ind] + + if mode == 'all': + other_visible_inds = agreement[unit_ind, :] > thresh + elif mode == 'best': + best_ind = np.argmax(agreement[unit_ind, :]) + if agreement[unit_ind, best_ind] > thresh: + other_visible_inds = [best_ind] + else: + other_visible_inds = [] + + other_visible_ids = other_controller.unit_ids[other_visible_inds] + other_controller.set_visible_unit_ids(other_visible_ids) + + for view in other_window.views.values(): + view.refresh() + + self.label.setText( + f'Analyzer {win_ind} : {visible_unit_ids}\n' + f'Analyzer {other_ind} : {other_visible_ids}\n' + + ) \ No newline at end of file diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 19fe9b4..802ead5 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -377,4 +377,75 @@ def find_skippable_extensions(layout_dict): skippable_extensions = list(all_extensions.difference(set(needed_extensions))) - return skippable_extensions \ No newline at end of file + return skippable_extensions + + +def run_compare_analyzer( + analyzers, + mode="desktop", + with_traces=False, + # displayed_unit_properties=None, + skip_extensions=None, + layout_preset=None, + layout=None, + verbose=False, + # user_settings=None, + # disable_save_settings_button=False, +): + + assert isinstance(analyzers, list) + assert len(analyzers) == 2 + assert mode == "desktop" + + from spikeinterface_gui.myqt import QT, mkQApp + from spikeinterface_gui.backend_qt import QtMainWindow, ControllerSynchronizer + from spikeinterface.comparison import compare_two_sorters + + layout_dict = get_layout_description(layout_preset, layout) + + + + controllers = [] + windows = [] + for i, analyzer in enumerate(analyzers): + if verbose: + import time + t0 = time.perf_counter() + + controller = Controller( + analyzer, backend="qt", + # verbose=verbose, + verbose=False, + + curation=False, + with_traces=with_traces, + skip_extensions=skip_extensions, + ) + if verbose: + t1 = time.perf_counter() + print('controller init time', t1 - t0) + + + # Suppress a known pyqtgraph warning + warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph") + warnings.filterwarnings('ignore', category=UserWarning, message=".*QObject::connect.*") + + + app = mkQApp() + + win = QtMainWindow(controller, layout_dict=layout_dict) #, user_settings=user_settings) + win.setWindowTitle(f'Analyzer {i}') + # Set window icon + icon_file = Path(__file__).absolute().parent / 'img' / 'si.png' + if icon_file.exists(): + app.setWindowIcon(QT.QIcon(str(icon_file))) + win.show() + windows.append(win) + controllers.append(controller) + + comp = compare_two_sorters(analyzers[0].sorting, analyzers[1].sorting) + + synchronizer = ControllerSynchronizer(comp, controllers, windows) + synchronizer.show() + + app.exec() \ No newline at end of file From 533434d80b9308dc74abdf94d0670e94901cb140 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 21 Nov 2025 11:49:03 +0100 Subject: [PATCH 2/6] improvement --- spikeinterface_gui/backend_qt.py | 6 +++--- spikeinterface_gui/main.py | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index 830111a..34f0849 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -449,8 +449,8 @@ def __init__(self, sorting_comparison, controllers, windows, parent=None): view.notifier.unit_visibility_changed.connect(callback) settings = [ - {'name': 'mode', 'type': 'list', 'limits' : ['all', 'best', ] }, - {'name': 'thresh', 'type': 'float', 'value' : 0.05, 'step': 0.01, 'limits': (0, 1.)}, + {'name': 'mode', 'type': 'list', 'limits' : ['best', 'all',] }, + {'name': 'thresh', 'type': 'float', 'value' : 0.3, 'step': 0.01, 'limits': (0, 1.)}, ] self.settings = pg.parametertree.Parameter.create(name="settings", type='group', children=settings) @@ -509,4 +509,4 @@ def on_unit_visibility_changed(self, win_ind): f'Analyzer {win_ind} : {visible_unit_ids}\n' f'Analyzer {other_ind} : {other_visible_ids}\n' - ) \ No newline at end of file + ) diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 802ead5..d2d896c 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -389,6 +389,7 @@ def run_compare_analyzer( layout_preset=None, layout=None, verbose=False, + names=None, # user_settings=None, # disable_save_settings_button=False, ): @@ -434,7 +435,11 @@ def run_compare_analyzer( app = mkQApp() win = QtMainWindow(controller, layout_dict=layout_dict) #, user_settings=user_settings) - win.setWindowTitle(f'Analyzer {i}') + if names is None: + name = f'Analyzer {i}' + else: + name = names[i] + win.setWindowTitle(name) # Set window icon icon_file = Path(__file__).absolute().parent / 'img' / 'si.png' if icon_file.exists(): From dc007a981eb848fc6199c6a3d87da3675ec6ca81 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 21 Nov 2025 12:23:19 +0100 Subject: [PATCH 3/6] agreement map --- spikeinterface_gui/backend_qt.py | 72 +++++++++++++++++++++++++++++--- spikeinterface_gui/main.py | 10 ++--- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index 34f0849..2b3e600 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -3,6 +3,7 @@ import markdown import numpy as np from copy import copy +import itertools import weakref @@ -423,12 +424,13 @@ def refresh(self): class ControllerSynchronizer(QT.QWidget): - def __init__(self, sorting_comparison, controllers, windows, parent=None): + def __init__(self, sorting_comparison, controllers, windows, names, parent=None): QT.QWidget.__init__(self, parent=parent) self.comp = sorting_comparison self.controllers = controllers self.windows = windows + self.names = names self.layout = QT.QVBoxLayout() self.setLayout(self.layout) @@ -461,6 +463,41 @@ def __init__(self, sorting_comparison, controllers, windows, parent=None): self.tree_settings.setWindowTitle('Settings') self.layout.addWidget(self.tree_settings) + from .utils_qt import ViewBoxHandlingClickToPositionWithCtrl + + self.graphicsview = pg.GraphicsView() + self.layout.addWidget(self.graphicsview) + self.viewBox = ViewBoxHandlingClickToPositionWithCtrl() + self.viewBox.clicked.connect(self._qt_select_pair) + self.viewBox.disableAutoRange() + self.plot = pg.PlotItem(viewBox=self.viewBox) + self.graphicsview.setCentralItem(self.plot) + self.plot.hideButtons() + self.image = pg.ImageItem() + self.plot.addItem(self.image) + self.plot.hideAxis('bottom') + self.plot.hideAxis('left') + + + import matplotlib + N = 512 + cmap_name = 'viridis' + cmap = matplotlib.colormaps[cmap_name].resampled(N) + lut = [] + for i in range(N): + r,g,b,_ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) + lut.append([r*255,g*255,b*255]) + self.lut = np.array(lut, dtype='uint8') + + agreement = self.comp.agreement_scores.values + self.image.setImage(agreement , lut=self.lut, levels=[0, 1]) + self.image.show() + self.plot.setXRange(0, agreement.shape[0]) + self.plot.setLabel('bottom', names[0]) + self.plot.setYRange(0, agreement.shape[1]) + self.plot.setLabel('left', names[1]) + + def on_unit_visibility_changed_0(self): self.on_unit_visibility_changed(0) @@ -505,8 +542,33 @@ def on_unit_visibility_changed(self, win_ind): for view in other_window.views.values(): view.refresh() - self.label.setText( - f'Analyzer {win_ind} : {visible_unit_ids}\n' - f'Analyzer {other_ind} : {other_visible_ids}\n' + + self._refresh_label() + + def _refresh_label(self): + + txt = '' + unit_ids0 = self.controllers[0].get_visible_unit_ids() + unit_ids1 = self.controllers[1].get_visible_unit_ids() + for unit_id0, unit_id1 in itertools.product(unit_ids0, unit_ids1): + a = self.comp.agreement_scores.loc[unit_id0, unit_id1] + txt += f'{self.names[0]} unit {unit_id0} - {self.names[1]} unit {unit_id1} agreement={a}' + self.label.setText(txt) + + def _qt_select_pair(self, x, y, reset): + c0 = self.controllers[0] + c1 = self.controllers[1] + + + unit_id0 = c0.unit_ids[int(np.floor(x))] + unit_id1 = c1.unit_ids[int(np.floor(y))] - ) + c0.set_visible_unit_ids([unit_id0]) + c1.set_visible_unit_ids([unit_id1]) + + for win in self.windows: + for view in win.views.values(): + view.refresh() + + self._refresh_label() + diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index d2d896c..bddca84 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -404,7 +404,8 @@ def run_compare_analyzer( layout_dict = get_layout_description(layout_preset, layout) - + if names is None: + names = [f'Analyzer {i}' for i in range(2)] controllers = [] windows = [] @@ -435,10 +436,7 @@ def run_compare_analyzer( app = mkQApp() win = QtMainWindow(controller, layout_dict=layout_dict) #, user_settings=user_settings) - if names is None: - name = f'Analyzer {i}' - else: - name = names[i] + name = names[i] win.setWindowTitle(name) # Set window icon icon_file = Path(__file__).absolute().parent / 'img' / 'si.png' @@ -450,7 +448,7 @@ def run_compare_analyzer( comp = compare_two_sorters(analyzers[0].sorting, analyzers[1].sorting) - synchronizer = ControllerSynchronizer(comp, controllers, windows) + synchronizer = ControllerSynchronizer(comp, controllers, windows, names) synchronizer.show() app.exec() \ No newline at end of file From b0a3152c28d94bcf597d148a243d80fa1e6882b3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 21 Nov 2025 14:06:26 +0100 Subject: [PATCH 4/6] yep --- spikeinterface_gui/backend_qt.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index 2b3e600..ea315bc 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -489,12 +489,13 @@ def __init__(self, sorting_comparison, controllers, windows, names, parent=None) lut.append([r*255,g*255,b*255]) self.lut = np.array(lut, dtype='uint8') - agreement = self.comp.agreement_scores.values - self.image.setImage(agreement , lut=self.lut, levels=[0, 1]) + # agreement = self.comp.agreement_scores.values + self.agreement_ordered = self.comp.get_ordered_agreement_scores() + self.image.setImage(self.agreement_ordered.values , lut=self.lut, levels=[0, 1]) self.image.show() - self.plot.setXRange(0, agreement.shape[0]) + self.plot.setXRange(0, self.agreement_ordered.shape[0]) self.plot.setLabel('bottom', names[0]) - self.plot.setYRange(0, agreement.shape[1]) + self.plot.setYRange(0, self.agreement_ordered.shape[1]) self.plot.setLabel('left', names[1]) @@ -552,16 +553,18 @@ def _refresh_label(self): unit_ids1 = self.controllers[1].get_visible_unit_ids() for unit_id0, unit_id1 in itertools.product(unit_ids0, unit_ids1): a = self.comp.agreement_scores.loc[unit_id0, unit_id1] - txt += f'{self.names[0]} unit {unit_id0} - {self.names[1]} unit {unit_id1} agreement={a}' + txt += f'{self.names[0]} unit {unit_id0} - {self.names[1]} unit {unit_id1} agreement={a}\n' self.label.setText(txt) def _qt_select_pair(self, x, y, reset): c0 = self.controllers[0] c1 = self.controllers[1] - - unit_id0 = c0.unit_ids[int(np.floor(x))] - unit_id1 = c1.unit_ids[int(np.floor(y))] + # used + ordered_unit_ids1 = self.agreement_ordered.index + ordered_unit_ids2 = self.agreement_ordered.columns + unit_id0 = ordered_unit_ids1[int(np.floor(x))] + unit_id1 = ordered_unit_ids2[int(np.floor(y))] c0.set_visible_unit_ids([unit_id0]) c1.set_visible_unit_ids([unit_id1]) From a3a191cdfc728adbefbd4b178945647f16da25fa Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 24 Nov 2025 13:00:31 +0100 Subject: [PATCH 5/6] oups --- spikeinterface_gui/main.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index bddca84..67878be 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -402,6 +402,9 @@ def run_compare_analyzer( from spikeinterface_gui.backend_qt import QtMainWindow, ControllerSynchronizer from spikeinterface.comparison import compare_two_sorters + app = mkQApp() + + layout_dict = get_layout_description(layout_preset, layout) if names is None: @@ -410,6 +413,7 @@ def run_compare_analyzer( controllers = [] windows = [] for i, analyzer in enumerate(analyzers): + if verbose: import time t0 = time.perf_counter() @@ -423,28 +427,28 @@ def run_compare_analyzer( with_traces=with_traces, skip_extensions=skip_extensions, ) + if verbose: t1 = time.perf_counter() print('controller init time', t1 - t0) - # Suppress a known pyqtgraph warning - warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph") - warnings.filterwarnings('ignore', category=UserWarning, message=".*QObject::connect.*") + # Suppress a known pyqtgraph warning + warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph") + warnings.filterwarnings('ignore', category=UserWarning, message=".*QObject::connect.*") - app = mkQApp() - win = QtMainWindow(controller, layout_dict=layout_dict) #, user_settings=user_settings) - name = names[i] - win.setWindowTitle(name) - # Set window icon - icon_file = Path(__file__).absolute().parent / 'img' / 'si.png' - if icon_file.exists(): - app.setWindowIcon(QT.QIcon(str(icon_file))) - win.show() - windows.append(win) - controllers.append(controller) + win = QtMainWindow(controller, layout_dict=layout_dict) #, user_settings=user_settings) + name = names[i] + win.setWindowTitle(name) + # Set window icon + icon_file = Path(__file__).absolute().parent / 'img' / 'si.png' + if icon_file.exists(): + app.setWindowIcon(QT.QIcon(str(icon_file))) + win.show() + windows.append(win) + controllers.append(controller) comp = compare_two_sorters(analyzers[0].sorting, analyzers[1].sorting) From 26312ce842fe438c116da5f1a87ade6966f44584 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sat, 29 Nov 2025 12:34:56 +0100 Subject: [PATCH 6/6] test compare --- spikeinterface_gui/tests/test_compare_qt.py | 43 +++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 spikeinterface_gui/tests/test_compare_qt.py diff --git a/spikeinterface_gui/tests/test_compare_qt.py b/spikeinterface_gui/tests/test_compare_qt.py new file mode 100644 index 0000000..d3783e9 --- /dev/null +++ b/spikeinterface_gui/tests/test_compare_qt.py @@ -0,0 +1,43 @@ +from argparse import ArgumentParser +from spikeinterface_gui import run_compare_analyzer + +from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder, make_curation_dict + +from spikeinterface import load_sorting_analyzer + + +from pathlib import Path + +import numpy as np +import sys + + + + +def setup_module(): + global test_folder + case = test_folder.stem.split('_')[-1] + make_analyzer_folder(test_folder, case=case) + +def teardown_module(): + clean_all(test_folder) + + +def test_run_compare_analyzer(): + analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") + analyzers = [analyzer, analyzer] + run_compare_analyzer( + analyzers, + mode="desktop", + verbose=True, + ) + +if __name__ == '__main__': + global test_folder + + dataset = "small" + test_folder = Path(dataset).parent / f"my_dataset_{dataset}" + if not test_folder.is_dir(): + setup_module() + + win = test_run_compare_analyzer()