diff --git a/pyproject.toml b/pyproject.toml index acfa121..f436dda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface-gui" -version = '0.12.0' +version = '0.12.1' authors = [ { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 8d78a71..9dc75eb 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -13,6 +13,7 @@ from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.core.core_tools import check_json from spikeinterface.curation import validate_curation_dict +from spikeinterface.curation.curation_model import CurationModel from spikeinterface.widgets.utils import make_units_table_from_analyzer from .curation_tools import add_merge, default_label_definitions, empty_curation_data @@ -804,7 +805,8 @@ def construct_final_curation(self): d["format_version"] = "2" d["unit_ids"] = self.unit_ids.tolist() d.update(self.curation_data.copy()) - return d + model = CurationModel(**d) + return model def save_curation_in_analyzer(self): if self.analyzer.format == "memory": @@ -814,8 +816,9 @@ def save_curation_in_analyzer(self): folder = self.analyzer.folder / "spikeinterface_gui" folder.mkdir(exist_ok=True, parents=True) json_file = folder / f"curation_data.json" - with json_file.open("w") as f: - json.dump(check_json(self.construct_final_curation()), f, indent=4) + curation_model = self.construct_final_curation() + with open(json_file, "w") as f: + f.write(curation_model.model_dump_json(indent=4)) self.current_curation_saved = True elif self.analyzer.format == "zarr": import zarr @@ -823,7 +826,8 @@ def save_curation_in_analyzer(self): if "spikeinterface_gui" not in zarr_root.keys(): sigui_group = zarr_root.create_group("spikeinterface_gui", overwrite=True) sigui_group = zarr_root["spikeinterface_gui"] - sigui_group.attrs["curation_data"] = check_json(self.construct_final_curation()) + curation_model = self.construct_final_curation() + sigui_group.attrs["curation_data"] = curation_model.model_dump(mode="json") self.current_curation_saved = True def get_split_unit_ids(self): @@ -974,10 +978,14 @@ def get_unit_label(self, unit_id, category): if ix is None: return lbl = self.curation_data["manual_labels"][ix] - if category in lbl: + if "labels" in lbl and category in lbl["labels"]: + # v2 format + labels = lbl["labels"][category] + return labels[0] + elif category in lbl: + # v1 format labels = lbl[category] return labels[0] - def set_label_to_unit(self, unit_id, category, label): if label is None: @@ -987,13 +995,16 @@ def set_label_to_unit(self, unit_id, category, label): ix = self.find_unit_in_manual_labels(unit_id) if ix is not None: lbl = self.curation_data["manual_labels"][ix] - if category in lbl: - lbl[category] = [label] - else: + if "labels" in lbl and category in lbl["labels"]: + # v2 format + lbl["labels"][category] = [label] + elif category in lbl: + # v1 format lbl[category] = [label] + else: - lbl = {"unit_id": unit_id, category:[label]} - self.curation_data["manual_labels"].append(lbl) + manual_label = {"unit_id": unit_id, "labels": {category: [label]}} + self.curation_data["manual_labels"].append(manual_label) if self.verbose: print(f"Set label {category} to {label} for unit {unit_id}") @@ -1002,6 +1013,8 @@ def remove_category_from_unit(self, unit_id, category): if ix is None: return lbl = self.curation_data["manual_labels"][ix] + + # curation v1 if category in lbl: lbl.pop(category) if len(lbl) == 1: @@ -1009,3 +1022,7 @@ def remove_category_from_unit(self, unit_id, category): self.curation_data["manual_labels"].pop(ix) if self.verbose: print(f"Remove label {category} for unit {unit_id}") + # curation v2 + elif lbl.get('labels') is not None and category in lbl.get('labels'): + lbl['labels'].pop(category) + self.curation_data["manual_labels"][ix] = lbl diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index 427ecd0..463d7c0 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -285,9 +285,9 @@ def _qt_export_json(self): fd.setViewMode(QT.QFileDialog.Detail) if fd.exec_(): json_file = Path(fd.selectedFiles()[0]) + curation_model = self.controller.construct_final_curation() with json_file.open("w") as f: - curation_dict = check_json(self.controller.construct_final_curation()) - json.dump(curation_dict, f, indent=4) + f.write(curation_model.model_dump_json(indent=4)) self.controller.current_curation_saved = True # PANEL @@ -529,10 +529,9 @@ def _panel_generate_json(self): # Get the path from the text input export_path = "curation.json" # Save the JSON file - curation_dict = check_json(self.controller.construct_final_curation()) - - with open(export_path, "w") as f: - json.dump(curation_dict, f, indent=4) + curation_model = self.controller.construct_final_curation() + with export_path.open("w") as f: + f.write(curation_model.model_dump_json(indent=4)) self.controller.current_curation_saved = True @@ -543,14 +542,14 @@ def _panel_generate_json(self): def _panel_submit_to_parent(self, event): """Send the curation data to the parent window""" # Get the curation data and convert it to a JSON string - curation_data = json.dumps(check_json(self.controller.construct_final_curation())) + curation_model = self.controller.construct_final_curation() # Create a JavaScript snippet that will send the data to the parent window js_code = f"""