diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index dc98056..3583f77 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -8,99 +8,46 @@ import functools import json import os +import glob import pickle import shutil import tarfile import tempfile +import re +import subprocess from abc import ABC, abstractmethod from collections import defaultdict from pathlib import Path -from shutil import copyfile +from shutil import copyfile, copytree from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np import pandas as pd -import probeinterface as pi import scipy.stats import spikeinterface as si import spikeinterface.extractors as se import spikeinterface.sorters as ss +import spikeinterface.curation as sc +import spikeinterface.exporters as sexp +import spikeinterface.qualitymetrics as sqm from scipy import interpolate from tables import HDF5ExtError from pixels import ioutils -from pixels import signal +import pixels.signal_utils as signal +import pixels.pixels_utils as xut from pixels.error import PixelsError +from pixels.constants import * +from pixels.configs import * +from pixels.stream import Stream +from pixels.units import SelectedUnits + +from common_utils.file_utils import load_yaml if TYPE_CHECKING: from typing import Optional, Literal -BEHAVIOUR_HZ = 25000 - -np.random.seed(BEHAVIOUR_HZ) - - -def _cacheable(method): - """ - Methods with this decorator will have their output cached to disk so that future - calls with the same set of arguments will simply load the result from disk. However, - if the key word argument list contains `units` and it is not either `None` or an - instance of `SelectedUnits` then this is disabled. - """ - def func(*args, **kwargs): - name = kwargs.pop("name", None) - - if "units" in kwargs: - units = kwargs["units"] - if not isinstance(units, SelectedUnits) or not hasattr(units, "name"): - return method(*args, **kwargs) - - self, *as_list = list(args) + list(kwargs.values()) - if not self._use_cache: - return method(*args, **kwargs) - - arrays = [i for i, arg in enumerate(as_list) if isinstance(arg, np.ndarray)] - if arrays: - if name is None: - raise PixelsError( - 'Cacheing methods when passing arrays requires also passing name="something"' - ) - for i in arrays: - as_list[i] = name - - as_list.insert(0, method.__name__) - output = self.interim / 'cache' / ('_'.join(str(i) for i in as_list) + '.h5') - if output.exists() and self._use_cache != "overwrite": - try: - df = ioutils.read_hdf5(output) - except HDF5ExtError: - df = None - else: - df = method(*args, **kwargs) - output.parent.mkdir(parents=True, exist_ok=True) - if df is None: - output.touch() - else: - ioutils.write_hdf5(output, df) - return df - return func - - -class SelectedUnits(list): - name: str - """ - This is the return type of Behaviour.select_units, which is a list in every way - except that when represented as a string, it can return a name, if a `name` - attribute has been set on it. This allows methods that have had `units` passed to be - cached to file. - """ - def __repr__(self): - if hasattr(self, "name"): - return self.name - return list.__repr__(self) - - class Behaviour(ABC): """ This class represents a single individual recording session. @@ -123,41 +70,62 @@ class Behaviour(ABC): """ - sample_rate = 1000 + SAMPLE_RATE = SAMPLE_RATE - def __init__(self, name, data_dir, metadata=None, interim_dir=None): + def __init__(self, name, data_dir, metadata=None, processed_dir=None, + interim_dir=None, hist_dir=None): self.name = name + self.mouse_id = name.split("_")[-1] self.data_dir = data_dir self.metadata = metadata self.raw = self.data_dir / 'raw' / self.name - self.processed = self.data_dir / 'processed' / self.name - self.files = ioutils.get_data_files(self.raw, name) if interim_dir is None: self.interim = self.data_dir / 'interim' / self.name else: - self.interim = Path(interim_dir) / self.name + self.interim = Path(interim_dir).expanduser() / self.name + + if processed_dir is None: + self.processed = self.data_dir / 'processed' / self.name + else: + self.processed = Path(processed_dir).expanduser() / self.name + self.backup = self.data_dir / 'processed' / self.name + self.backup.mkdir(parents=True, exist_ok=True) + + if hist_dir is None: + self.histology = self.data_dir / 'histology'\ + / processed / self.mouse_id + else: + self.histology = Path(hist_dir).expanduser() / self.mouse_id + + self.files = ioutils.get_data_files(self.raw, name) + + self.CatGT_dir = sorted(glob.glob( + str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' + )) self.interim.mkdir(parents=True, exist_ok=True) self.processed.mkdir(parents=True, exist_ok=True) self._action_labels = None self._behavioural_data = None - self._spike_data = None + self._ap_data = None self._spike_times_data = None self._lfp_data = None self._lag = None self._use_cache = True self._cluster_info = None + self._good_unit_info = None + self._probe_depths = None self.drop_data() - self.spike_meta = [ - ioutils.read_meta(self.find_file(f['spike_meta'])) for f in self.files - ] - self.lfp_meta = [ - ioutils.read_meta(self.find_file(f['lfp_meta'])) for f in self.files - ] + self.ap_meta = [] + for stream_id in self.files["pixels"]: + for meta in self.files["pixels"][stream_id]["ap_meta"]: + self.ap_meta.append( + ioutils.read_meta(self.find_file(meta, copy=True)) + ) # environmental variable PIXELS_CACHE={0,1} can be {disable,enable} cache self.set_cache(bool(int(os.environ.get("PIXELS_CACHE", 1)))) @@ -166,13 +134,19 @@ def drop_data(self): """ Clear attributes that store data to clear some memory. """ - self._action_labels = [None] * len(self.files) - self._behavioural_data = [None] * len(self.files) - self._spike_data = [None] * len(self.files) - self._spike_times_data = [None] * len(self.files) - self._spike_rate_data = [None] * len(self.files) - self._lfp_data = [None] * len(self.files) - self._motion_index = [None] * len(self.files) + # NOTE: number of behaviour session is independent of number of probes + self.stream_count = len(self.files["pixels"]) + self.behaviour_count = len(self.files["behaviour"]["action_labels"]) + + self._action_labels = [None] * self.behaviour_count + self._behavioural_data = [None] * self.behaviour_count + self._ap_data = [None] * self.stream_count + self._spike_times_data = [None] * self.stream_count + self._spike_rate_data = [None] * self.stream_count + self._lfp_data = [None] * self.stream_count + self._motion_index = [None] * self.behaviour_count + self._cluster_info = [None] * self.stream_count + self._probe_depths = [None] * self.stream_count self._load_lag() def set_cache(self, on: bool | Literal["overwrite"]) -> None: @@ -187,27 +161,46 @@ def _load_lag(self): Load previously-calculated lag information from a saved file if it exists, otherwise return Nones. """ - lag_file = self.processed / 'lag.json' - self._lag = [None] * len(self.files) + lag_file = self.processed / "lag.json" + self._lag = [None] * self.behaviour_count if lag_file.exists(): with lag_file.open() as fd: lag = json.load(fd) for rec_num, rec_lag in enumerate(lag): - if rec_lag['lag_start'] is None: + if rec_lag["lag_start"] is None: self._lag[rec_num] = None else: - self._lag[rec_num] = (rec_lag['lag_start'], rec_lag['lag_end']) + self._lag[rec_num] = (rec_lag["lag_start"], rec_lag["lag_end"]) def get_probe_depth(self): """ Load probe depth in um from file if it has been recorded. """ - depth_file = self.processed / 'depth.txt' - if not depth_file.exists(): - msg = f": Can't load probe depth: please add it in um to processed/{self.name}/depth.txt" - raise PixelsError(msg) - with depth_file.open() as fd: - return [float(line) for line in fd.readlines()] + for stream_num, depth in enumerate(self._probe_depths): + # TODO jun 12 2024 skip stream 1 for now + if stream_num > 0: + continue + if depth is None: + try: + depth_file = self.processed / "depth.txt" + with depth_file.open() as fd: + self._probe_depths[stream_num] = [float(line) for line in + fd.readlines()][0] + except: + depth_file = self.processed / self.files[stream_num]["depth_info"] + #self._probe_depths[stream_num] = json.load(open(depth_file, mode="r"))["clustering"] + self._probe_depths[stream_num] = json.load( + open(depth_file, mode="r"))["manipulator"] + else: + msg = f": Can't load probe depth: please add it in um to\ + \nprocessed/{self.name}/depth.txt, or save it with other depth related\ + \ninfo in {self.processed / self.files[0]['depth_info']}." + raise PixelsError(msg) + + #if Path(depth_file).suffix == ".txt": + #elif Path(depth_file).suffix == ".json": + # return [json.load(open(depth_file, mode="r"))["clustering"]] + return self._probe_depths def find_file(self, name: str, copy: bool=True) -> Optional[Path]: """ @@ -229,7 +222,20 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: pathlib.Path : the full path to the desired file in the correct folder. """ + backup = self.backup / name processed = self.processed / name + if hasattr(self, "backup") and backup.exists()\ + and (not processed.exists()): + if copy: + logging.info( + f"\n {self.name}: Copying {name} to local processed" + ) + try: + copyfile(backup, processed) + except IsADirectoryError: + copytree(backup, processed) + return processed + if processed.exists(): return processed @@ -240,7 +246,7 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: raw = self.raw / name if raw.exists(): if copy: - print(f" {self.name}: Copying {name} to interim") + logging.info(f"\n {self.name}: Copying {name} to interim") copyfile(raw, interim) return interim return raw @@ -248,11 +254,11 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: tar = raw.with_name(raw.name + '.tar.gz') if tar.exists(): if copy: - print(f" {self.name}: Extracting {tar.name} to interim") + logging.info(f"\n {self.name}: Extracting {tar.name} to interim") with tarfile.open(tar) as open_tar: open_tar.extractall(path=self.interim) return interim - print(f" {self.name}: Extracting {tar.name}") + logging.info(f"\n {self.name}: Extracting {tar.name}") with tarfile.open(tar) as open_tar: open_tar.extractall(path=self.raw) return raw @@ -280,30 +286,33 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): The sync channel from either the spike or LFP data. """ - print(" Finding lag between sync channels") + # TODO jan 14 2025: + # this func is not used in vr behaviour, since they are synched + # in vd.session + logging.info("\n Finding lag between sync channels") recording = self.files[rec_num] if behavioural_data is None: - print(" Loading behavioural data") + logging.info("\n Loading behavioural data") data_file = self.find_file(recording['behaviour']) behavioural_data = ioutils.read_tdms(data_file, groups=["NpxlSync_Signal"]) behavioural_data = signal.resample( - behavioural_data.values, BEHAVIOUR_HZ, self.sample_rate + behavioural_data.values, BEHAVIOUR_HZ, self.SAMPLE_RATE ) if sync_channel is None: - print(" Loading neuropixels sync channel") + logging.info("\n Loading neuropixels sync channel") data_file = self.find_file(recording['lfp_data']) num_chans = self.lfp_meta[rec_num]['nSavedChans'] sync_channel = ioutils.read_bin(data_file, num_chans, channel=384) orig_rate = int(self.lfp_meta[rec_num]['imSampRate']) #sync_channel = sync_channel[:120 * orig_rate * 2] # 2 mins, rec Hz, back/forward - sync_channel = signal.resample(sync_channel, orig_rate, self.sample_rate) + sync_channel = signal.resample(sync_channel, orig_rate, self.SAMPLE_RATE) behavioural_data = signal.binarise(behavioural_data) sync_channel = signal.binarise(sync_channel) - print(" Finding lag") + logging.info("\n Finding lag") plot = self.processed / f'sync_{rec_num}.png' lag_start, match = signal.find_sync_lag( behavioural_data, sync_channel, plot=plot, @@ -313,8 +322,9 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): self._lag[rec_num] = (lag_start, lag_end) if match < 95: - print(" The sync channels did not match very well. Check the plot.") - print(f" Calculated lag: {(lag_start, lag_end)}") + logging.warning("\n The sync channels did not match very well. " + "Check the plot.") + logging.info(f"\n Calculated lag: {(lag_start, lag_end)}") lag_json = [] for lag in self._lag: @@ -326,16 +336,157 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): with (self.processed / 'lag.json').open('w') as fd: json.dump(lag_json, fd) + def sync_streams(self, SYNC_BIN, remap_stream_idx): + """ + Neuropixels data streams acquired simultaneously are not synchronised, unless + they are plugged into the same headstage, which is only the case for + Neuropixels 2.0 probes. + Dual-recording data acquired by Neuropixels 1.0 probes needs to be + synchronised. + Specifically, spike times from imecX stream need to be re-mapped to imec0 + time scale. + + For more info, see + https://open-ephys.github.io/gui-docs/Tutorials/Data-Synchronization.html and + https://github.com/billkarsh/TPrime. + + params + === + SYNC_BIN: int, number of rising sync edges for calculating the scaling + factor. + """ + # TODO jan 14 2025: + # this func does not work with the new self.files structure + + edges_list = [] + stream_ids = [] + #self.CatGT_dir = Path(self.CatGT_dir[0]) + output = self.ks_outputs[remap_stream_idx] / f'spike_times_remapped.npy' + + # do not redo the remapping if not necessary + if output.exists(): + logging.info("\n> Spike times from " + f"{self.ks_outputs[remap_stream_idx]} already remapped, next session.") + cluster_times = self.get_spike_times()[remap_stream_idx] + remapped_cluster_times = self.get_spike_times( + remapped=True)[remap_stream_idx] + + # get first spike time from each cluster, and their difference + clusters_first = cluster_times.iloc[0,:] + remapped_clusters_first = remapped_cluster_times.iloc[0,:] + remap = remapped_clusters_first - clusters_first + + return remap + + for rec_num, recording in enumerate(self.files): + # get file names and stuff + # TODO jan 4 check if find_file works for catgt data + spike_data = self.find_file(recording['CatGT_ap_data']) + spike_meta = self.find_file(recording['CatGT_ap_meta']) + #spike_data = self.CatGT_dir / recording['CatGT_ap_data'] + #spike_meta = self.CatGT_dir / recording['CatGT_ap_meta'] + stream_id = spike_data.as_posix()[-12:-4] + stream_ids.append(stream_id) + #self.gate_idx = spike_data.as_posix()[-18:-16] + self.trigger_idx = spike_data.as_posix()[-15:-13] + + # find extracted rising sync edges, rn from CatGT + try: + edges_file = sorted(glob.glob( + rf'{self.CatGT_dir}' + f'/*{stream_id}.xd*.txt', recursive=True) + )[0] + except IndexError as e: + raise PixelsError( + f"Can't load sync pulse rising edges. Did you run CatGT and\ + extract edges? Full error: {e}\n" + ) + # read sync edges, ms + edges = np.loadtxt(edges_file) + # pick edges by defined bin + binned_edges = np.append( + arr=edges[0::SYNC_BIN], + values=edges[-1], + ) + edges_list.append(binned_edges) + + # make list np array and calculate difference between streams to get the + # initial difference + edges = np.array(edges_list) + initial_dt = np.diff(edges, axis=0).squeeze() + # save initial diff for later plotting + np.save(self.processed / 'sync_streams_lag.npy', initial_dt) + + # load spike times that needs to be remapped + times = self.ks_outputs[remap_stream_idx] / f'spike_times.npy' + try: + times = np.load(times) + except FileNotFoundError: + msg = ": Can't load spike times that haven't been extracted!" + raise PixelsError(self.name + msg) + times = np.squeeze(times) + + # convert spike times to ms + orig_rate = int(self.spike_meta[0]['imSampRate']) + times_ms = times * self.SAMPLE_RATE / orig_rate + + lag = [None, 'later', 'earlier'] + logging.info(f"""\n> {stream_ids[0]} started\r + {abs(initial_dt[0]*1000):.2f}ms {lag[int(np.sign(initial_dt[0]))]}\r + and finished\r + {abs(initial_dt[-1]*1000):.2f}ms {lag[-int(np.sign(initial_dt[0]))]}\r + than {stream_ids[1]}.""") + + # create a np array for remapped spike times from imec1 + remapped_times_ms = np.zeros(times.shape) + + # get edge difference within & between streams + within_streams = np.diff(edges) + # calculate scaling factor for each sync bin + scales = within_streams[0] / within_streams[-1] + + # find out which sync period/bin each spike time belongs to, and use + # the scale from that period/bin + for t, time in enumerate(times_ms): + bin_idx = np.where(time > edges[remap_stream_idx])[0][0] + remapped_times_ms[t] = ((time - edges[remap_stream_idx][bin_idx]) * + scales[bin_idx]) + edges[0][bin_idx] + + logging.info(f"""\n> Remap stats {stream_ids[remap_stream_idx]} spike times:\r + median shift {np.median(remapped_times_ms-times_ms):.2f}ms,\r + min shift {np.min(remapped_times_ms-times_ms):.2f}ms,\r + max shift {np.max(remapped_times_ms-times_ms):.2f}ms.""") + + # convert remappmed times back to its original sample index + remapped_times = np.uint64(remapped_times_ms * orig_rate / self.SAMPLE_RATE) + np.save(output, remapped_times) + logging.info(f'\n> Spike times remapping output saved to\n {output}.') + + # load remapped spike times of each cluster + cluster_times = self.get_spike_times()[remap_stream_idx] + remapped_cluster_times = self.get_spike_times( + remapped=True)[remap_stream_idx] + + # get first spike time from each cluster, and their difference + clusters_first = cluster_times.iloc[0,:] + remapped_clusters_first = remapped_cluster_times.iloc[0,:] + remap = remapped_clusters_first - clusters_first + + return remap + + def process_behaviour(self): """ Process behavioural data from raw tdms and align to neuropixels data. """ + # NOTE jan 14 2025: + # this func is not used by vr behaviour for rec_num, recording in enumerate(self.files): - print( - f">>>>> Processing behaviour for recording {rec_num + 1} of {len(self.files)}" + logging.info( + f"\n>>>>> Processing behaviour for recording {rec_num + 1}" + f" of {len(self.files)}" ) - print(f"> Loading behavioural data") + logging.info(f"\n> Loading behavioural data") behavioural_data = ioutils.read_tdms(self.find_file(recording['behaviour'])) # ignore any columns that have Nans; these just contain settings @@ -343,12 +494,12 @@ def process_behaviour(self): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - print(f"> Downsampling to {self.sample_rate} Hz") - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + logging.info(f"\n> Downsampling to {self.SAMPLE_RATE} Hz") + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] - print(f"> Syncing to Neuropixels data") + logging.info(f"\n> Syncing to Neuropixels data") if self._lag[rec_num] is None: self.sync_data( rec_num, @@ -358,152 +509,388 @@ def process_behaviour(self): behavioural_data = behavioural_data[max(lag_start, 0):-1-max(lag_end, 0)] behavioural_data.index = range(len(behavioural_data)) - print(f"> Extracting action labels") + logging.info(f"\n> Extracting action labels") self._action_labels[rec_num] = self._extract_action_labels(rec_num, behavioural_data) output = self.processed / recording['action_labels'] np.save(output, self._action_labels[rec_num]) - print(f"> Saved to: {output}") + logging.info(f"\n> Saved to: {output}") output = self.processed / recording['behaviour_processed'] - print(f"> Saving downsampled behavioural data to:") - print(f" {output}") + logging.info(f"\n> Saving downsampled behavioural data to:") + logging.info(f"\n {output}") behavioural_data.drop("/'NpxlSync_Signal'/'0'", axis=1, inplace=True) ioutils.write_hdf5(output, behavioural_data) self._behavioural_data[rec_num] = behavioural_data - print("> Done!") + logging.info("\n> Done!") + - def process_spikes(self): + def correct_ap_motion(self, mc_method="dredge"): """ - Process the spike data from the raw neural recording data. + Correct motion of recording. + + params + === + mc_method: str, motion correction method. + Default: "dredge". + (as of jan 2025, dredge performs better than ks motion correction.) + "ks": let kilosort do motion correction. + + return + === + None """ - for rec_num, recording in enumerate(self.files): - print( - f">>>>> Processing spike data for recording {rec_num + 1} of {len(self.files)}" + if mc_method == "ks": + logging.info(f"\n> Correct motion later with {mc_method}.") + return None + + # get pixels streams + streams = self.files["pixels"] + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + output = self.find_file(stream_files["ap_motion_corrected"]) + if output: + logging.info(f"\n> {stream_id} already motion corrected.") + stream_files["ap_motion_corrected"] = si.load(output) + continue + else: + output = self.processed / stream_files["ap_motion_corrected"] + + logging.info( + f"\n>>>>> Correcting motion for ap band from {stream_id} " + f"in total of {self.stream_count} stream(s) with {mc_method}" + ) + + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + stream.correct_ap_motion() + + stream_files["ap_motion_corrected"].save( + format="zarr", + folder=output, + compressor=wv_compressor, + ) + + if hasattr(self, "backup"): + # copy to backup if backup setup + # zarr is a directory hence use copytree + copytree(output, self.backup / output.name) + logging.info(f"\n> Motion corrected copied to {self.backup}.") + + return None + + + def correct_lfp_motion(self): + raise NotImplementedError("> Not implemented.") + + + def preprocess_raw(self): + """ + Preprocess full-band raw pixels data. + + params + === + mc_method: str, motion correction method. + Default: "dredge". + (as of jan 2025, dredge performs better than ks motion correction.) + "ks": let kilosort do motion correction. + + return + === + None + """ + # get pixels streams + streams = self.files["pixels"] + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, ) - output = self.processed / recording['spike_processed'] + stream.preprocess_raw() + + return None + + + def detect_n_localise_peaks(self, loc_method="monopolar_triangulation"): + """ + Get a sense of possible drifts in the recordings by looking at a + "positional raster plot", i.e. the depth of the spike as function of + time. To do so, we need to detect the peaks, and then to localize them + in space. + + params + === + rec: spikeinterface recording extractor. + + loc_method: str, peak location method. + Default: "monopolar_triangulation" + list of methods: + "center_of_mass", "monopolar_triangulation", "grid_convolution" + to learn more, check: + https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html + """ + self.extract_bands("ap") + # get pixels streams + streams = self.files["pixels"] + + for stream_id, stream_files in streams.items(): + output = self.processed / stream_files["detected_peaks"] if output.exists(): + logging.info(f"\n> Peaks from {stream_id} already detected.") continue - data_file = self.find_file(recording['spike_data']) - orig_rate = self.spike_meta[rec_num]['imSampRate'] - num_chans = self.spike_meta[rec_num]['nSavedChans'] + # get ap band + rec = stream_files["ap_extracted"] - print("> Mapping spike data") - data = ioutils.read_bin(data_file, num_chans) + # detect and localise peaks + df = xut.detect_n_localise_peaks(rec) + + # write to disk + ioutils.write_hdf5(output, df) + + if hasattr(self, "backup"): + # copy to backup if backup setup + copyfile(output, self.backup / output.name) + logging.info(f"\n> Detected peaks copied to {self.backup}.") + + return None - print(f"> Downsampling to {self.sample_rate} Hz") - data = signal.resample(data, orig_rate, self.sample_rate) - # Ideally we would median subtract before downsampling, but that takes a - # very long time and is at risk of memory errors, so let's do it after. - print("> Performing median subtraction across rows") - data = signal.median_subtraction(data, axis=0) - print("> Performing median subtraction across columns") - data = signal.median_subtraction(data, axis=1) + def extract_bands(self, freqs=None, preprocess=True): + """ + extract data of ap and lfp frequency bands from the raw neural recording + data. + """ + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + logging.info( + f"\n>>>>> Extracting bands from {self.name} " + f"{stream_id} in total of {self.stream_count} stream(s)" + ) + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + stream.extract_bands(freqs, preprocess) + """ if self._lag[rec_num] is None: self.sync_data(rec_num, sync_channel=data[:, -1]) lag_start, lag_end = self._lag[rec_num] - print(f"> Saving data to {output}") + logging.info(f"\n> Saving data to {output}") if lag_end < 0: data = data[:lag_end] if lag_start < 0: data = data[- lag_start:] data = pd.DataFrame(data[:, :-1]) ioutils.write_hdf5(output, data) + """ - def process_lfp(self): - """ - Process the LFP data from the raw neural recording data. + return None + + + def run_catgt(self, CatGT_app=None, args=None) -> None: """ - for rec_num, recording in enumerate(self.files): - print(f">>>>> Processing LFP for recording {rec_num + 1} of {len(self.files)}") + This func performs CatGT on copied AP data in the interim. - data_file = self.find_file(recording['lfp_data']) - orig_rate = self.lfp_meta[rec_num]['imSampRate'] - num_chans = self.lfp_meta[rec_num]['nSavedChans'] + params + ==== + data_dir: path, dir to interim data and catgt output. - print("> Mapping LFP data") - data = ioutils.read_bin(data_file, num_chans) + catgt_app: path, dir to catgt software. - output = self.processed / recording['lfp_processed'] - if output.exists(): - continue + args: str, arguments in catgt. + default is None. + """ + # TODO jan 14 2025: + # this func is deprecated + assert 0, "deprecated" + if CatGT_app == None: + CatGT_app = "~/CatGT3.4" + # move cwd to catgt + os.chdir(CatGT_app) + + # reset catgt args for current session + session_args = None + + for f in self.files: + # copy spike data to interim + self.find_file(f['spike_data']) + + if (isinstance(self.CatGT_dir, list) and + len(self.CatGT_dir) != 0 and + len(os.listdir(self.CatGT_dir[0])) != 0): + logging.info(f"\nCatGT already performed on ap data of {self.name}." + " Next session.\n") + return + else: + #TODO: finish this here so that catgt can run together with sorting + logging.info(f"\n> Running CatGT on ap data of {self.name}") + #_dir = self.interim + + if args == None: + #args = f"-no_run_fld\ + # -g=0,9\ + # -t=0,9\ + # -prb=0:1\ + # -prb_miss_ok\ + # -ap\ + # -lf\ + # -apfilter=butter,12,300,9000\ + # -lffilter=butter,12,0.5,300\ + # -xd=2,0,384,6,350,160\ + # -gblcar\ + # -gfix=0.2,0.1,0.02" + args = f"-no_run_fld\ + -g=0\ + -t=0\ + -prb=0\ + -prb_miss_ok\ + -ap\ + -xd=2,0,384,6,20,15\ + -xid=2,0,384,6,20,15" + + session_args = f"-dir={self.interim} -run={self.name} -dest={self.interim} " + args + logging.info(f"\ncatgt args of {self.name}: \n{session_args}") + + subprocess.run( ['./run_catgt.sh', session_args]) + + # make sure CatGT_dir is set after running + self.CatGT_dir = sorted(glob.glob( + str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' + )) + + def load_raw_ap(self): + """ + Load and concatenate raw recording files for each stream (i.e., probe), + so that data from all runs of the same probe can be preprocessed and + sorted together. + """ + # if multiple runs for the same probe, concatenate them + streams = self.files["pixels"] + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) - print("> Performing median subtraction across channels for each timepoint") - subtracted = signal.median_subtraction(data, axis=1) + logging.info( + f"\n> Loading raw {stream_id} data." + ) + raw_rec = stream.load_raw_ap() - print(f"> Downsampling to {self.sample_rate} Hz") - downsampled = signal.resample(subtracted, orig_rate, self.sample_rate, False) - sync_chan = downsampled[:, -1] - downsampled = downsampled[:, :-1] + # now the value for streams dict is recording extractor + stream_files["si_rec"] = raw_rec - if self._lag[rec_num] is None: - self.sync_data(rec_num, sync_channel=data[:, -1]) - lag_start, lag_end = self._lag[rec_num] + return None - sd = self.processed / recording['lfp_sd'] - if sd.exists(): - continue - SDs = [] - for i in range(downsampled.shape[1]): - SDs.append(np.std(downsampled[:, i])) - results = dict( - median=np.median(SDs), - SDs=SDs, + def whiten_ap(self): + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, ) - print(f"> Saving standard deviation (and their median) of each channel") - with open(sd, 'w') as fd: - json.dump(results, fd) + logging.info(f"\n> Whitening {stream_id}.") + stream.whiten_ap() + + return None - if lag_end < 0: - data = data[:lag_end] - if lag_start < 0: - data = data[- lag_start:] - print(f"> Saving median subtracted & downsampled LFP to {output}") - # save in .npy format - np.save( - file=output, - arr=downsampled, - allow_pickle=True, - ) - #downsampled = pd.DataFrame(downsampled) - #ioutils.write_hdf5(output, downsampled) - def sort_spikes(self): + def sort_spikes(self, mc_method="ks"): """ Run kilosort spike sorting on raw spike data. + + params + === + mc_method: str, motion correction method. + Default: "ks". + (as of may 2025, feeding ap dredged recording to kilosort 4 + gives much less unit, so just let ks4 does its thing.) + "ks": do motion correction with kilosort. + "dredge": do motion correction with dredge on ap band. """ - streams = {} + ks_image_path = self.interim.parent / ks4_image_name - for rec_num, files in enumerate(self.files): - data_file = self.find_file(files['spike_data']) - assert data_file, f"Spike data not found for {files['spike_data']}." + if not ks_image_path.exists(): + raise PixelsError("Have you craeted Singularity image for sorting?") - stream_id = data_file.as_posix()[-12:-4] - if stream_id not in streams: - metadata = self.find_file(files['spike_meta']) - streams[stream_id] = metadata + # ap band motion correct ONLY for building sorting analyser + self.correct_ap_motion() - for stream_num, stream in enumerate(streams.items()): - stream_id, metadata = stream - try: - recording = se.SpikeGLXRecordingExtractor(self.interim, stream_id=stream_id) - except ValueError as e: - raise PixelsError( - f"Did the raw data get fully copied to interim? Full error: {e}" - ) + if mc_method == "ks": + ks_mc = True + else: + ks_mc = False + # XXX: no whitening + #self.whiten_ap() + + # set ks4 parameters + ks4_params = { + "do_CAR": False, # do not common average reference + "skip_kilosort_preprocessing": False, + "do_correction": ks_mc, + "save_preprocessed_copy": True, # save ks4 preprocessed data + } + + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + # check if already sorted and exported + sorter_output = self.find_file( + stream_files["sorting_analyser"].parent + ) + sa_dir = self.processed / stream_files["sorting_analyser"] + if sa_dir.exists(): + logging.info("\n> Already sorted and exported, next stream.") + continue + + # get catgt directory + catgt_dir = self.find_file( + stream_files["CatGT_ap_data"][stream_num] + ) + + # find spike sorting output folder + if catgt_dir is None: + output = sa_dir.parent + else: + output = self.processed / f"sorted_stream_cat_{stream_num}" - print("> Running kilosort") - output = self.processed / f'sorted_stream_{stream_num}' - concat_rec = si.concatenate_recordings([recording]) - probe = pi.read_spikeglx(metadata.as_posix()) - concat_rec = concat_rec.set_probe(probe) - ss.run_kilosort3(recording=concat_rec, output_folder=output) + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + stream.sort_spikes( + ks_mc=ks_mc, + ks_image_path=ks_image_path, + ks4_params=ks4_params, + output=output, + sa_dir=sa_dir, + ) + if hasattr(self, "backup"): + # copy to backup if backup setup + copytree(output, self.backup / output.name) + logging.info(f"\n> Sorter ourput copied to {self.backup}.") + + return None def extract_videos(self, force=False): @@ -546,7 +933,7 @@ def configure_motion_tracking(self, project: str) -> None: videos.append(self.interim / video.with_suffix('.avi')) if not videos: - print(self.name, ": No matching videos for project:", project) + logging.info(self.name, ": No matching videos for project:", project) return if config: @@ -556,7 +943,7 @@ def configure_motion_tracking(self, project: str) -> None: copy_videos=False, ) else: - print(f"Config not found.") + logging.warning(f"\nConfig not found.") reply = input("Create new project? [Y/n]") if reply and reply[0].lower() == "n": raise PixelsError("A DLC project is needed for motion tracking.") @@ -674,7 +1061,7 @@ def _align_dlc_coords(self, rec_num, metadata, coords): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] @@ -852,6 +1239,7 @@ def process_motion_index(self, video_match): # Get MIs avi = self.interim / video.with_suffix('.avi') + #TODO: how to load recording rec_rois, roi_file = ses_rois[(rec_num, v)] rec_mi = signal.motion_index(avi, rec_rois) @@ -863,7 +1251,7 @@ def process_motion_index(self, video_match): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] trigger = signal.binarise(behavioural_data["/'CamFrames'/'0'"]).values @@ -933,14 +1321,14 @@ def add_motion_index_action_label( action_labels = self.get_action_labels() motion_indexes = self.get_motion_index_data() - scan_duration = self.sample_rate * 10 + scan_duration = self.SAMPLE_RATE * 10 half = scan_duration // 2 # We take 200 ms before the action begins as a short baseline period for each # trial. The smallest standard deviation of all SDs of these baseline periods is # used as a threshold to identify "clean" trials (`clean_threshold` below). # Non-clean trials are trials where TODO - short_pre = int(0.2 * self.sample_rate) + short_pre = int(0.2 * self.SAMPLE_RATE) for rec_num, recording in enumerate(self.files): # Only recs with camera_data will have motion indexes @@ -1001,7 +1389,7 @@ def _extract_action_labels(self, behavioural_data): """ - def _get_processed_data(self, attr, key): + def _get_processed_data(self, attr, key, category): """ Used by the following get_X methods to load processed data. @@ -1018,18 +1406,28 @@ def _get_processed_data(self, attr, key): saved = getattr(self, attr) if saved[0] is None: - for rec_num, recording in enumerate(self.files): - if key in recording: - file_path = self.processed / recording[key] - if file_path.exists(): - if file_path.suffix == '.npy': - saved[rec_num] = np.load(file_path) - elif file_path.suffix == '.h5': - saved[rec_num] = ioutils.read_hdf5(file_path) - else: - msg = f"Could not find {attr[1:]} for recording {rec_num}." - msg += f"\nFile should be at: {file_path}" - raise PixelsError(msg) + files = self.files[category] + else: + return saved + + if key in files: + dirs = files[key] + for f, file_dir in enumerate(dirs): + file_path = self.find_file(file_dir) + try: + assert (file_path is not None) + except: + file_path = self.processed / file_dir + if file_path.exists(): + if re.search(r'\.np[yz]$', file_path.suffix): + saved[f] = np.load(file_path) + elif file_path.suffix == '.h5': + saved[f] = ioutils.read_hdf5(file_path) + else: + msg = f"Could not find {attr[1:]} for recording {rec_num}." + msg += f"\nFile should be at: {file_path}" + raise PixelsError(msg) + return saved def get_action_labels(self): @@ -1037,7 +1435,8 @@ def get_action_labels(self): Returns the action labels, either from self._action_labels if they have been loaded already, or from file. """ - return self._get_processed_data("_action_labels", "action_labels") + return self._get_processed_data("_action_labels", "action_labels", + "behaviour") def get_behavioural_data(self): """ @@ -1091,7 +1490,7 @@ def get_spike_data(self): """ Returns the processed and downsampled spike data. """ - return self._get_processed_data("_spike_data", "spike_processed") + return self._get_processed_data("_ap_data", "spike_processed") def get_lfp_data(self): """ @@ -1099,131 +1498,111 @@ def get_lfp_data(self): """ return self._get_processed_data("_lfp_data", "lfp_processed") - def _get_spike_times(self): + + def _get_si_spike_times(self, units): """ - Returns the sorted spike times. + get spike times in second with spikeinterface """ - saved = self._spike_times_data - if saved[0] is None: - times = self.processed / f'sorted_stream_0' / 'spike_times.npy' - clust = self.processed / f'sorted_stream_0' / 'spike_clusters.npy' + spike_times = self._spike_times_data - try: - times = np.load(times) - clust = np.load(clust) - except FileNotFoundError: - msg = ": Can't load spike times that haven't been extracted!" - raise PixelsError(self.name + msg) + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + # find sorting analyser, use merged if there is one + merged_sa_dir = self.find_file( + stream_files["merged_sorting_analyser"] + ) + if merged_sa_dir: + sa_dir = merged_sa_dir + else: + sa_dir = self.find_file(stream_files["sorting_analyser"]) + # load sorting analyser + temp_sa = si.load_sorting_analyzer(sa_dir) + + # select units + sorting = temp_sa.sorting.select_units(units) + sa = temp_sa.select_units(units) + sa.sorting = sorting + + times = {} + # get spike train + for i, unit_id in enumerate(sa.unit_ids): + unit_times = sa.sorting.get_unit_spike_train( + unit_id=unit_id, + return_times=False, + ) + times[unit_id] = pd.Series(unit_times) + # concatenate units + spike_times[stream_num] = pd.concat( + objs=times, + axis=1, + names="unit", + ) + # get sampling frequency + fs = int(sa.sampling_frequency) + # Convert to time into sample rate index + spike_times[stream_num] /= fs / self.SAMPLE_RATE - times = np.squeeze(times) - clust = np.squeeze(clust) - by_clust = {} + return spike_times[0] # NOTE: only deal with one stream for now - for c in np.unique(clust): - by_clust[c] = pd.Series(times[clust == c]).drop_duplicates() - saved[0] = pd.concat(by_clust, axis=1, names=['unit']) - return saved[0] - def _get_aligned_spike_times( - self, label, event, duration, rate=False, sigma=None, units=None - ): + def get_spike_times(self, units, remapped=False, use_si=False): """ - Returns spike times for each unit within a given time window around an event. - align_trials delegates to this function, and should be used for getting aligned - data in scripts. - """ - action_labels = self.get_action_labels() - - if units is None: - units = self.select_units() - - spikes = self._get_spike_times()[units] - # Convert to ms (self.sample_rate) - spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate - - if rate: - # pad ends with 1 second extra to remove edge effects from convolution - duration += 2 + Returns the sorted spike times. - scan_duration = self.sample_rate * 8 - half = int((self.sample_rate * duration) / 2) - cursor = 0 # In sample points - i = -1 - rec_trials = {} + params + === + remapped: bool, if using remapped (synced with imec0) spike times. + Default: False + """ + spike_times = self._spike_times_data - for rec_num in range(len(self.files)): - actions = action_labels[rec_num][:, 0] - events = action_labels[rec_num][:, 1] - trial_starts = np.where(np.bitwise_and(actions, label))[0] + for stream_num, stream in enumerate(range(len(spike_times))): + if use_si: + spike_times[stream_num] = self._get_si_spike_times(units) + else: + if remapped and stream_num > 0: + times = self.ks_outputs[stream_num] / f'spike_times_remapped.npy' + logging.info(f"""\n> Found remapped spike times from\r + {self.ks_outputs[stream_num]}, try to load this.""") + else: + times = self.ks_outputs[stream_num] / f'spike_times.npy' - # Account for multiple raw data files - meta = self.spike_meta[rec_num] - samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 - assert samples.is_integer() - milliseconds = samples / 30 - cursor_duration = cursor / 30 - rec_spikes = spikes[ - (cursor_duration <= spikes) & (spikes < (cursor_duration + milliseconds)) - ] - cursor_duration - cursor += samples - - # Account for lag, in case the ephys recording was started before the - # behaviour - lag_start, _ = self._lag[rec_num] - if lag_start < 0: - rec_spikes = rec_spikes + lag_start + clust = self.ks_outputs[stream_num] / f'spike_clusters.npy' - for i, start in enumerate(trial_starts, start=i + 1): - centre = np.where(np.bitwise_and(events[start:start + scan_duration], event))[0] - if len(centre) == 0: - # See comment in align_trials as to why we just continue instead of - # erroring like we used to here. - print("No event found for an action. If this is OK, ignore this.") - continue - centre = start + centre[0] + try: + times = np.load(times) + clust = np.load(clust) + except FileNotFoundError: + msg = ": Can't load spike times that haven't been extracted!" + raise PixelsError(self.name + msg) + + times = np.squeeze(times) + clust = np.squeeze(clust) + by_clust = {} + + for c in np.unique(clust): + c_times = times[clust == c] + uniques, counts = np.unique( + c_times, + return_counts=True, + ) + repeats = c_times[np.where(counts>1)] + if len(repeats>1): + logging.info(f"\n> removed {len(repeats)} double-counted " + "spikes from cluster {c}.") - trial = rec_spikes[centre - half < rec_spikes] - trial = trial[trial <= centre + half] - trial = trial - centre - tdf = [] - - for unit in trial: - u_times = trial[unit].values - u_times = u_times[~np.isnan(u_times)] - u_times = np.unique(u_times) # remove double-counted spikes - udf = pd.DataFrame({int(unit): u_times}) - tdf.append(udf) - - assert len(tdf) == len(units) - if tdf: - tdfc = pd.concat(tdf, axis=1) - if rate: - tdfc = signal.convolve(tdfc, duration * 1000, sigma) - rec_trials[i] = tdfc - - if not rec_trials: - return None + by_clust[c] = pd.Series(uniques) + spike_times[stream_num] = pd.concat(by_clust, axis=1, names=['unit']) + # Convert to time into sample rate index + spike_times[stream_num] /= int(self.spike_meta[0]['imSampRate'])\ + / self.SAMPLE_RATE - trials = pd.concat(rec_trials, axis=1, names=["trial", "unit"]) - trials = trials.reorder_levels(["unit", "trial"], axis=1) - trials = trials.sort_index(level=0, axis=1) - - if rate: - # Set index to seconds and remove the padding 1 sec at each end - points = trials.shape[0] - start = (- duration / 2) + (duration / points) - # Having trouble with float values - #timepoints = np.linspace(start, duration / 2, points, dtype=np.float64) - timepoints = list(range(round(start * 1000), int(duration * 1000 / 2) + 1)) - trials['time'] = pd.Series(timepoints, index=trials.index) / 1000 - trials = trials.set_index('time') - trials = trials.iloc[self.sample_rate : - self.sample_rate] + return spike_times[0] - return trials def select_units( - self, group='good', min_depth=0, max_depth=None, min_spike_width=None, - max_spike_width=None, uncurated=False, name=None + self, min_depth=0, max_depth=None, min_spike_width=None, + unit_kwargs=None, max_spike_width=None, name=None, ): """ Select units based on specified criteria. The output of this can be passed to @@ -1231,10 +1610,6 @@ def select_units( Parameters ---------- - group : str, optional - The group to which the units that are wanted are part of. One of: 'group', - 'mua', 'noise' or None. Default is 'good'. - min_depth : int, optional (Only used when getting spike data). The minimum depth that units must be at to be included. Default is 0 i.e. in the brain. @@ -1251,9 +1626,6 @@ def select_units( (Only used when getting spike data). The maximum median spike width that units must have to be included. Default is None i.e. no maximum. - uncurated : bool, optional - Use uncurated units. Default: False. - name : str, optional Give this selection of units a name. This allows the list of units to be represented as a string, which enables caching. Future calls to cacheable @@ -1262,52 +1634,73 @@ def select_units( is the same between uses of the same name. """ - cluster_info = self.get_cluster_info() - selected_units = SelectedUnits() - if name is not None: - selected_units.name = name - - if min_depth is not None or max_depth is not None: - probe_depth = self.get_probe_depth()[0] - - if min_spike_width == 0: - min_spike_width = None - if min_spike_width is not None or max_spike_width is not None: - widths = self.get_spike_widths() - else: - widths = None - - rec_num = 0 - - id_key = 'id' if 'id' in cluster_info else 'cluster_id' - grouping = 'KSLabel' if uncurated else 'group' - - for unit in cluster_info[id_key]: - unit_info = cluster_info.loc[cluster_info[id_key] == unit].iloc[0].to_dict() + selected_units = SelectedUnits(name=name) if name is not None\ + else SelectedUnits() + + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + # find sorting analyser, use merged if there is one + merged_sa_dir = self.find_file( + stream_files["merged_sorting_analyser"] + ) + if merged_sa_dir: + sa_dir = merged_sa_dir + else: + sa_dir = self.find_file(stream_files["sorting_analyser"]) - # we only want units that are in the specified group - if not group or unit_info[grouping] == group: + # load sorting analyser + temp_sa = si.load_sorting_analyzer(sa_dir) + # NOTE: si.load gives warning when using temp_wh.dat to build + # sorting analyser, and to load sa it will be loading binary + # recording obj, and that checks version - # and that are within the specified depth range - if min_depth is not None: - if probe_depth - unit_info['depth'] < min_depth: - continue - if max_depth is not None: - if probe_depth - unit_info['depth'] > max_depth: - continue - - # and that have the specified median spike widths - if widths is not None: - width = widths[widths['unit'] == unit]['median_ms'] - assert len(width.values) == 1 - if min_spike_width is not None: - if width.values[0] < min_spike_width: - continue - if max_spike_width is not None: - if width.values[0] > max_spike_width: - continue + # remove noisy units + try: + noisy_units = load_yaml( + path=self.find_file(stream_files["noisy_units"]), + ) + except: + raise PixelsError("> Have you labelled noisy units?") + # remove units from sorting and reattach to sa to keep properties + sorting = temp_sa.sorting.remove_units(remove_unit_ids=noisy_units) + sa = temp_sa.remove_units(remove_unit_ids=noisy_units) + sa.sorting = sorting + + # get units + unit_ids = sa.unit_ids + + if name == "all": + selected_units[stream_id] = unit_ids + continue - selected_units.append(unit) + # get shank id for units + shank_ids = sa.sorting.get_property("group") + + # get coordinates of channel with max. amplitude + max_chan_coords = sa.sorting.get_property("max_chan_coords") + # get depths + depths = max_chan_coords[:, 1] + + if unit_kwargs: + for shank_id, kwargs in unit_kwargs.items(): + # get shank depths + min_depth = kwargs["min_depth"] + max_depth = kwargs["max_depth"] + # find units + in_range = unit_ids[ + (depths >= min_depth) & (depths < max_depth) &\ + (shank_ids == shank_id) + ] + # add to list + selected_units.extend(stream_id, in_range) + else: + # if there is only one shank + # find units + in_range = unit_ids[ + (depths >= min_depth) & (depths < max_depth) + ] + # add to list + selected_units.extend(stream_id, in_range) return selected_units @@ -1315,10 +1708,10 @@ def _get_neuro_raw(self, kind): raw = [] meta = getattr(self, f"{kind}_meta") for rec_num, recording in enumerate(self.files): - data_file = self.find_file(recording[f'{kind}_data']) + data_file = self.find_file(recording[f'{kind}_data'], copy=False) orig_rate = int(meta[rec_num]['imSampRate']) num_chans = int(meta[rec_num]['nSavedChans']) - factor = orig_rate / self.sample_rate + factor = orig_rate / self.SAMPLE_RATE data = ioutils.read_bin(data_file, num_chans) @@ -1348,10 +1741,11 @@ def get_lfp_data_raw(self): """ return self._get_neuro_raw('lfp') - @_cacheable + def align_trials( - self, label, event, data='spike_times', raw=False, duration=1, sigma=None, - units=None, dlc_project=None, video_match=None, + self, label, event, units=None, data='spike_event', raw=False, + duration=1, sigma=None, dlc_project=None, video_match=None, + end_event=None, ): """ Get trials aligned to an event. This finds all instances of label in the action @@ -1394,28 +1788,58 @@ def align_trials( When aligning video or motion index data, use this fnmatch pattern to select videos. + end_event : int | None + For VR behaviour, when aligning to the whole trial, this param is + the end event to align to. """ data = data.lower() data_options = [ 'behavioural', # Channels from behaviour TDMS file 'spike', # Raw/downsampled channels from probe (AP) - 'spike_times', # List of spike times per unit - 'spike_rate', # Spike rate signals from convolved spike times 'lfp', # Raw/downsampled channels from probe (LFP) 'motion_index', # Motion index per ROI from the video 'motion_tracking', # Motion tracking coordinates from DLC + 'spike_trial', # Taking spike times from the whole duration of each + # trial, convolve into spike rate, output also + # contains times. + 'spike_event', # Taking spike times from +/- 2s of an event, + # convolve into spike rate, output also + # contains times. ] if data not in data_options: raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") - if data in ("spike_times", "spike_rate"): - print(f"Aligning {data} to trials.") - # we let a dedicated function handle aligning spike times - return self._get_aligned_spike_times( - label, event, duration, rate=data == "spike_rate", sigma=sigma, - units=units + streams = self.files["pixels"] + output = {} + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, ) + output[stream_id] = stream.align_trials( + units=units, # NOTE: ALWAYS the first arg + data=data, # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + + if units.name == "all" and event.name == "trial_start"\ + and end_event.name == "trial_end": + stream.get_spike_chance( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + + return output if data == "motion_tracking" and not dlc_project: raise PixelsError("When aligning to 'motion_tracking', dlc_project is needed.") @@ -1423,32 +1847,32 @@ def align_trials( action_labels = self.get_action_labels() if raw: - print(f"Aligning raw {data} data to trials.") + logging.info(f"\nAligning raw {data} data to trials.") getter = getattr(self, f"get_{data}_data_raw", None) if not getter: raise PixelsError(f"align_trials: {data} doesn't have a 'raw' option.") - values, sample_rate = getter() + values, SAMPLE_RATE = getter() else: - print(f"Aligning {data} data to trials.") + logging.info(f"\nAligning {data} data to trials.") if dlc_project: values = self.get_motion_tracking_data(dlc_project) elif data == "motion_index": values = self.get_motion_index_data(video_match) else: values = getattr(self, f"get_{data}_data")() - sample_rate = self.sample_rate + SAMPLE_RATE = self.SAMPLE_RATE if not values or values[0] is None: raise PixelsError(f"align_trials: Could not get {data} data.") trials = [] # The logic here is that the action labels will always have a sample rate of - # self.sample_rate, whereas our data here may differ. 'duration' is used to scan + # self.SAMPLE_RATE, whereas our data here may differ. 'duration' is used to scan # the action labels, so always give it 5 seconds to scan, then 'half' is used to # index data. - scan_duration = self.sample_rate * 10 - half = (sample_rate * duration) // 2 + scan_duration = self.SAMPLE_RATE * 10 + half = (SAMPLE_RATE * duration) // 2 if isinstance(half, float): assert half.is_integer() # In case duration is a float < 1 half = int(half) @@ -1476,10 +1900,11 @@ def align_trials( # here to warn the user in case it is an error, while otherwise # continuing. #raise PixelsError('Action labels probably miscalculated') - print("No event found for an action. If this is OK, ignore this.") + logging.info("\nNo event found for an action. If this is" + " OK, ignore this.") continue centre = start + centre[0] - centre = int(centre * sample_rate / self.sample_rate) + centre = int(centre * SAMPLE_RATE / self.SAMPLE_RATE) trial = values[rec_num][centre - half + 1:centre + half + 1] if isinstance(trial, np.ndarray): @@ -1527,8 +1952,8 @@ def align_clips(self, label, event, video_match, duration=1): """ action_labels = self.get_action_labels() - scan_duration = self.sample_rate * 8 - half = int((self.sample_rate * duration) / 2) + scan_duration = self.SAMPLE_RATE * 8 + half = int((self.SAMPLE_RATE * duration) / 2) cursor = 0 # In sample points i = -1 rec_trials = [] @@ -1548,8 +1973,9 @@ def align_clips(self, label, event, video_match, duration=1): trial_starts = np.where(np.bitwise_and(actions, label))[0] behavioural_data = ioutils.read_tdms(self.find_file(recording['behaviour'])) + assert 0 behavioural_data = behavioural_data["/'CamFrames'/'0'"] - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array)] = np.squeeze(behav_array) behavioural_data = behavioural_data[:len(behav_array)] trigger = signal.binarise(behavioural_data).values @@ -1570,7 +1996,7 @@ def align_clips(self, label, event, video_match, duration=1): for start in trial_starts: centre = np.where(np.bitwise_and(events[start:start + scan_duration], event))[0] if len(centre) == 0: - print("No event found for an action. If this is OK, ignore this.") + logging.info("\nNo event found for an action. If this is OK, ignore this.") continue centre = start + centre[0] frames = timings.loc[ @@ -1599,17 +2025,32 @@ def align_clips(self, label, event, video_match, duration=1): return trials def get_cluster_info(self): - if self._cluster_info is None: - info_file = self.processed / 'sorted_stream_0' / 'cluster_info.tsv' + for stream_num, info in enumerate(self._cluster_info): + if info is None: + info_file = self.ks_outputs[stream_num] / 'cluster_info.tsv' + try: + info = pd.read_csv(info_file, sep='\t') + except FileNotFoundError: + msg = ": Can't load cluster info. Did you sort this session yet?" + raise PixelsError(self.name + msg) + self._cluster_info[stream_num] = info + return self._cluster_info + + def get_good_units_info(self): + if self._good_unit_info is None: + #az: good_units_info.tsv saved while running depth_profile.py + info_file = self.interim / 'good_units_info.tsv' + #logging.info(f"\n> got good unit info at {info_file}\n") + try: info = pd.read_csv(info_file, sep='\t') except FileNotFoundError: - msg = ": Can't load cluster info. Did you sort this session yet?" + msg = ": Can't load cluster info. Did you export good unit info for this session yet?" raise PixelsError(self.name + msg) - self._cluster_info = info - return self._cluster_info + self._good_unit_info = info + return self._good_unit_info - @_cacheable + #@_cacheable def get_spike_widths(self, units=None): if units: # Always defer to getting widths for all units, so we only ever have to @@ -1617,7 +2058,7 @@ def get_spike_widths(self, units=None): all_widths = self.get_spike_widths() return all_widths.loc[all_widths.unit.isin(units)] - print("Calculating spike widths") + logging.info("\nCalculating spike widths") waveforms = self.get_spike_waveforms() widths = [] @@ -1640,48 +2081,242 @@ def get_spike_widths(self, units=None): df['median_ms'] = 1000 * df['median_ms'] / orig_rate return df - @_cacheable - def get_spike_waveforms(self, units=None): - from phylib.io.model import load_model - from phylib.utils.color import selected_cluster_color + #@_cacheable + def get_spike_waveforms(self, units=None, method='phy'): + """ + Extracts waveforms of spikes. + method: str, name of selected method. + 'phy' (default) + 'spikeinterface' + """ + if method == 'phy': + from phylib.io.model import load_model + from phylib.utils.color import selected_cluster_color + + if units: + # defer to getting waveforms for all units + waveforms = self.get_spike_waveforms()[units] + assert list(waveforms.columns.get_level_values("unit").unique()) == list(units) + return waveforms + + units = self.select_units() + + #paramspy = self.processed / 'sorted_stream_0' / 'params.py' + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure put old code under loop of stream so it does not break! + paramspy = self.ks_outputs / 'params.py' + if not paramspy.exists(): + raise PixelsError(f"{self.name}: params.py not found") + model = load_model(paramspy) + rec_forms = {} + + for u, unit in enumerate(units): + logging.info(f"\n{round(100 * u / len(units), 2)}% complete") + # get the waveforms from only the best channel + spike_ids = model.get_cluster_spikes(unit) + best_chan = model.get_cluster_channels(unit)[0] + u_waveforms = model.get_waveforms(spike_ids, [best_chan]) + if u_waveforms is None: + raise PixelsError(f"{self.name}: unit {unit} - waveforms not read") + rec_forms[unit] = pd.DataFrame(np.squeeze(u_waveforms).T) + + assert rec_forms + + df = pd.concat( + rec_forms, + axis=1, + names=['unit', 'spike'] + ) + # convert indexes to ms + rate = 1000 / int(self.spike_meta[0]['imSampRate']) + df.index = df.index * rate + return df + + #TODO: implement spikeinterface waveform extraction + elif method == 'spikeinterface': + ## set chunks + #job_kwargs = dict( + # n_jobs=-3, # -1: num of job equals num of cores + # chunk_duration="1s", + # progress_bar=True, + #) + recording, _ = self.load_recording() + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure put old code under loop of stream so it does not break! + try: + sorting = se.read_kilosort(self.ks_outputs) + except ValueError as e: + raise PixelsError( + f"Can't load sorting object. Did you delete cluster_info.csv? Full error: {e}\n" + ) + + # check last modified time of cache, and create time of ks_output + try: + template_cache_mod_time = os.path.getmtime(self.interim / + 'cache/templates_average.npy') + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure put old code under loop of stream so it does not break! + ks_mod_time = os.path.getmtime(self.ks_outputs / 'cluster_info.tsv') + assert template_cache_mod_time < ks_mod_time + check = True # re-extract waveforms + logging.info("\n> Re-extracting waveforms since kilosort output is newer.") + except: + if 'template_cache_mod_time' in locals(): + logging.info("\n> Loading existing waveforms.") + check = False # load existing waveforms + else: + logging.info("\n> Extracting waveforms since they are not extracted.") + check = True # re-extract waveforms + + """ + # for testing: get first 5 mins of the recording + fs = concat_rec.get_sampling_frequency() + test = concat_rec.frame_slice( + start_frame=0*fs, + end_frame=300*fs, + ) + test.annotate(is_filtered=True) + # check all annotations + test.get_annotation('is_filtered') + logging.info(test) + """ + + # extract waveforms + waveforms = si.extract_waveforms( + recording=recording, + sorting=sorting, + folder=self.interim / 'cache', + load_if_exists=not(check), # maybe re-extracted + max_spikes_per_unit=500, # None will extract all waveforms + ms_before=2.0, # time before trough + ms_after=3.0, # time after trough + overwrite=check, # overwrite depends on check + **job_kwargs, + ) + #TODO: use cache to export the results? - if units: - # defer to getting waveforms for all units - waveforms = self.get_spike_waveforms()[units] - assert list(waveforms.columns.get_level_values("unit").unique()) == list(units) return waveforms - units = self.select_units() - - paramspy = self.processed / 'sorted_stream_0' / 'params.py' - if not paramspy.exists(): - raise PixelsError(f"{self.name}: params.py not found") - model = load_model(paramspy) - rec_forms = {} - - for u, unit in enumerate(units): - print(100 * u / len(units), "% complete") - # get the waveforms from only the best channel - spike_ids = model.get_cluster_spikes(unit) - best_chan = model.get_cluster_channels(unit)[0] - u_waveforms = model.get_waveforms(spike_ids, [best_chan]) - if u_waveforms is None: - raise PixelsError(f"{self.name}: unit {unit} - waveforms not read") - rec_forms[unit] = pd.DataFrame(np.squeeze(u_waveforms).T) - - assert rec_forms - - df = pd.concat( - rec_forms, - axis=1, - names=['unit', 'spike'] - ) - # convert indexes to ms - rate = 1000 / int(self.spike_meta[0]['imSampRate']) - df.index = df.index * rate + else: + raise PixelsError(f"{self.name}: waveform extraction method {method} is\ + not implemented!") + + + #@_cacheable + def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): + """ + This func is a work-around of spikeinterface's equivalent: + https://github.com/SpikeInterface/spikeinterface/blob/master/spikeinterface/postprocessing/template_metrics.py. + dec 23rd 2022: motivation to write this function is that spikeinterface + 0.96.1 cannot load sorting object from `export_to_phy` output folder, i.e., i + cannot get updated clusters/units and their waveforms, which is a huge + problem for subsequent analyses e.g. unit type clustering. + + To learn more about waveform metrics, see + https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/mean_waveforms + and https://journals.physiology.org/doi/full/10.1152/jn.00680.2018. + + """ + if units: + # Always defer to getting waveform metrics for all good units, so we only + # ever have to calculate metrics for each once. + wave_metrics = self.get_waveform_metrics() + return wave_metrics.loc[wave_metrics.unit.isin(units)] + + # TODO june 2nd 2023: extract amplitude, i.e., abs(trough - peak) in mV + # make sure amplitude is in mV + # normalise these metrics before passing to k-means + columns = ["unit", "duration", "trough_peak_ratio", "half_width", + "repolarisation_slope", "recovery_slope"] + logging.info(f"\n> Calculating waveform metrics {columns[1:]}...\n") + + waveforms = self.get_spike_waveforms() + # remove nan values + waveforms.dropna() + units = waveforms.columns.get_level_values('unit').unique() + + output = {} + for i, unit in enumerate(units): + metrics = [] + #mean_waveform = waveforms[unit].mean(axis=1) + median_waveform = waveforms[unit].median(axis=1) + # normalise mean waveform to remove variance caused by distance! + norm_waveform = median_waveform / median_waveform.abs().max() + #TODO: test! also can try clustering on normalised meann waveform + mean_waveform = norm_waveform + + # time between trough to peak, in ms + trough_idx = np.argmin(mean_waveform) + peak_idx = trough_idx + np.argmax(mean_waveform.iloc[trough_idx:]) + if peak_idx == 0: + raise PixelsError(f"> Cannot find peak in mean waveform.\n") + if trough_idx == 0: + raise PixelsError(f"> Cannot find trough in mean waveform.\n") + duration = mean_waveform.index[peak_idx] - mean_waveform.index[trough_idx] + metrics.append(duration) + + # trough to peak ratio + trough_peak_ratio = mean_waveform.iloc[peak_idx] / mean_waveform.iloc[trough_idx] + metrics.append(trough_peak_ratio) + + # spike half width, in ms + half_amp = mean_waveform.iloc[trough_idx] / 2 + idx_pre_half = np.where(mean_waveform.iloc[:trough_idx] < half_amp) + idx_post_half = np.where(mean_waveform.iloc[trough_idx:] < half_amp) + # last occurence of mean waveform amp lower than half amp, before trough + if len(idx_pre_half[0]) == 0: + idx_pre_half = trough_idx - 1 + time_pre_half = mean_waveform.index[idx_pre_half] + else: + time_pre_half = mean_waveform.iloc[idx_pre_half[0] - 1].index[0] + # first occurence of mean waveform amp lower than half amp, after trough + time_post_half = mean_waveform.iloc[idx_post_half[0] + 1 + + trough_idx].index[-1] + half_width = time_post_half - time_pre_half + metrics.append(half_width) + + # repolarisation slope + returns = np.where(mean_waveform.iloc[trough_idx:] >= 0) + trough_idx + if len(returns[0]) == 0: + logging.info(f"\n> The mean waveformrns never returned to baseline?\n") + return_idx = mean_waveform.shape[0] - 1 + else: + return_idx = returns[0][0] + if return_idx - trough_idx < 2: + raise PixelsError(f"> The mean waveform returns to baseline too quickly,\ + \ndoes not make sense...\n") + repo_period = mean_waveform.iloc[trough_idx:return_idx] + repo_slope = scipy.stats.linregress( + x=repo_period.index.values, # time in ms + y=repo_period.values, # amp + ).slope + metrics.append(repo_slope) + + # recovery slope during user-defined recovery period + recovery_end_idx = peak_idx + window + recovery_end_idx = np.min([recovery_end_idx, mean_waveform.shape[0]]) + reco_period = mean_waveform.iloc[peak_idx:recovery_end_idx] + reco_slope = scipy.stats.linregress( + x=reco_period.index.values, # time in ms + y=reco_period.values, # amp + ).slope + metrics.append(reco_slope) + + # save metrics in output dictionary, key is unit id + output[unit] = metrics + + # save all template metrics as dataframe + df = pd.DataFrame(output).T.reset_index() + df.columns = columns + dtype = {"unit": int} + df = df.astype(dtype) + # see which cols have nan + df.isnull().sum() + return df - @_cacheable + + #@_cacheable def get_aligned_spike_rate_CI( self, label, event, start=0.000, step=0.100, end=1.000, @@ -1783,8 +2418,9 @@ def get_aligned_spike_rate_CI( trial_responses = [] for trial, t_start, t_end in zip(trials, start, end): if not (t_start < t_end): - print( - f"Warning: trial {trial} skipped in CI calculation due to bad timepoints" + logging.warning( + f"\nWarning: trial {trial} skipped in CI calculation" + " due to bad timepoints" ) continue trial_responses.append( @@ -1858,3 +2494,320 @@ def get_aligned_spike_rate_CI( df = pd.concat(cis, axis=1, names=['unit', 'bin']) df.set_index(pd.Index(percentiles, name="percentile"), inplace=True) return df + + + def get_positional_data( + self, label, event, end_event=None, sigma=None, units=None, + normalised=False, + ): + """ + Get positional firing rate of selected units in vr, and spatial + occupancy of each position. + """ + # NOTE: order of args matters for loading the cache! + # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded + + output = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting positional neural data of {units} units in " + f"<{label.name}> trials." + ) + + output[stream_id] = stream.get_positional_data( + units=units, # NOTE: put units first! + label=label, + event=event, + sigma=sigma, + end_event=end_event, + normalised=normalised, + ) + + return output + + + def get_binned_trials( + self, label, event, units=None, sigma=None, end_event=None, + time_bin=None, pos_bin=None, + ): + """ + Returns spike rate for each unit within a trial. + align_trials delegates to this function, and should be used for getting aligned + data in scripts. + + This function also saves binned data in the format that Andrew wants: + trials * units * temporal bins (ms) + + time_bin: str | None + For VR behaviour, size of temporal bin for spike rate data. + + pos_bin: int | None + For VR behaviour, size of positional bin for position data. + """ + binned = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting binned <{label.name}> trials from {stream_id} " + f"in {units}." + ) + binned[stream_id] = stream.get_binned_trials( + units=units, # NOTE: always the first arg! + label=label, + event=event, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + ) + + return binned + + + def get_binned_chance( + self, label, event, sigma, end_event, time_bin, pos_bin, + ): + """ + This function saves binned data in the format that Andrew wants: + trials * units * temporal bins (ms) + + time_bin: str | None + For VR behaviour, size of temporal bin for spike rate data. + + pos_bin: int | None + For VR behaviour, size of positional bin for position data. + """ + units = self.select_units(name="all") + + binned = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting binned <{label.name}> chance from {stream_id}." + ) + binned[stream_id] = stream.get_binned_chance( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + ) + + return binned + + + + def get_chance_positional_psd(self, units, label, event, sigma, end_event): + streams = self.files["pixels"] + chance_psd = {} + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + chance_psd[stream_id] = stream.get_chance_positional_psd( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg + ) + + return chance_psd + + + def get_spike_chance(self, label, event, sigma, end_event): + """ + Get trials aligned to an event. This finds all instances of label in the action + labels - these are the start times of the trials. Then this finds the first + instance of event on or after these start times of each trial. Then it cuts out + a period around each of these events covering all units, rearranges this data + into a MultiIndex DataFrame and returns it. + + Parameters + ---------- + label : int + An action label value to specify which trial types are desired. + + event : int + An event type value to specify which event to align the trials to. + + units : list of lists of ints, optional + The output from self.select_units, used to only apply this method to a + selection of units. + + end_event : int + For VR behaviour, when aligning to the whole trial, this param is + the end event to align to. + """ + units = self.select_units(name="all") + streams = self.files["pixels"] + output = {} + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + output[stream_id] = stream.get_spike_chance( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + + return output + + + def save_spike_chance(self, spiked, sigma, sample_rate): + # TODO apr 21 2025: + # do we put this func here or in stream.py??? + + # save index and columns to reconstruct df for shuffled data + ioutils.save_index_to_frame( + df=spiked, + path=self.interim / stream_files["shuffled_index"], + ) + ioutils.save_cols_to_frame( + df=spiked, + path=self.interim / stream_files["shuffled_columns"], + ) + + # get chance data paths + paths = { + "spiked_memmap_path": self.interim /\ + stream_files["spiked_shuffled_memmap"], + "fr_memmap_path": self.interim /\ + stream_files["fr_shuffled_memmap"], + } + + # save chance data + xut.save_spike_chance( + **paths, + sigma=sigma, + sample_rate=self.SAMPLE_RATE, + spiked=spiked, + ) + + return None + + + def sync_vr(self, vr_session): + """ + Synchronise each pixels stream with virtual reality data. + + params + === + vr: class, virtual reality session object. + """ + streams = self.files["pixels"] + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Synchonising {self.name} {stream_id} pixels data with vr." + ) + stream.sync_vr(vr_session) + + return None + + + def get_spatial_psd( + self, label, event, end_event=None, sigma=None, units=None, + crop_from=None, use_binned=False, time_bin=None, pos_bin=None, + ): + """ + Get spatial power spectral density of selected units. + """ + # NOTE: order of args matters for loading the cache! + # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded + + output = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting spatial PSD of {units} units in " + f"<{label.name}> trials." + ) + output[stream_id] = stream.get_spatial_psd( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + crop_from=crop_from, + use_binned=use_binned, + time_bin=time_bin, + pos_bin=pos_bin, + ) + + return output + + + def get_landmark_responsives(self, label, sigma=None, units=None): + output = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting landmark responsive units from {units} in " + f"<{label.name}> trials." + ) + output[stream_id] = stream.get_landmark_responsives( + units=units, + label=label, + sigma=sigma, + ) + + return output diff --git a/pixels/behaviours/passive_viewing.py b/pixels/behaviours/passive_viewing.py new file mode 100644 index 0000000..f107339 --- /dev/null +++ b/pixels/behaviours/passive_viewing.py @@ -0,0 +1,534 @@ +""" +This module provides passive viewing task specific operations. +""" + +from __future__ import annotations + +import pickle + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +from reach.session import Outcomes, Targets + +from pixels import Experiment, PixelsError +from pixels import signal, ioutils +from pixels.behaviours import Behaviour + + +class VisStimLabels: + """ + These visual stim labels cover all possible trial types. + 'sf_004' and 'sf_016' correspond to the grating trial's spatial frequency. + This means trials with all temporal frequencies are included here. + 'left' and 'right' notes the direction where the grating is drifting towards. + Gratings have six possible orientations in total, 30 deg aparts. + 'nat_movie' and 'cage_movie' correspond to movie trials, where a 60s of natural + or animal home cage movie is played. + + for ESCN00 to 03, start of each of these trials is marked by a rise of 500 ms TTL + pulse. + Within each grating trial, each orientation grating shows for 2s, their + beginnings and ends are both marked by a rise of 500 ms TTL pulse. + + *dec 16th 2022: in the near future, there will be 'nat_image' trial, where... + + To align trials to more than one action type they can be bitwise OR'd i.e. + `miss_left | miss_right` will match all miss trials. + """ + # gratings 0.04 & 0.16 spatial frequency + sf_004 = 1 << 0 + sf_016 = 1 << 1 + + # directions + left = 1 << 2 + right = 1 << 3 + + # orientations + g_0 = 1 << 4 + g_30 = 1 << 5 + g_60 = 1 << 6 + g_90 = 1 << 7 + g_120 = 1 << 8 + g_150 = 1 << 9 + + # iti marks + black = 1 << 10 + gray = 1 << 11 + + # movies + nat_movie = 1 << 12 + cage_movie = 1 << 13 + + """ + # spatial freq 0.04 directions 0:30:150 + g_004_0 = sf_004 & g_0 & left + g_004_30 = sf_004 & g_30 & left + g_004_60 = sf_004 & g_60 & left + g_004_90 = sf_004 & g_90 & left + g_004_120 = sf_004 & g_120 & left + g_004_150 = sf_004 & g_150 & left + + # spatial freq 0.04 directions 180:30:330 + g_004_180 = sf_004 & g_0 & right + g_004_210 = sf_004 & g_30 & right + g_004_240 = sf_004 & g_60 & right + g_004_270 = sf_004 & g_90 & right + g_004_300 = sf_004 & g_120 & right + g_004_330 = sf_004 & g_150 & right + + # spatial freq 0.16 directions 0:30:150 + g_016_0 = sf_016 & g_0 & left + g_016_30 = sf_016 & g_30 & left + g_016_60 = sf_016 & g_60 & left + g_016_90 = sf_016 & g_90 & left + g_016_120 = sf_016 & g_120 & left + g_016_150 = sf_016 & g_150 & left + + # spatial freq 0.16 directions 180:30:330 + g_016_180 = sf_016 & g_0 & right + g_016_210 = sf_016 & g_30 & right + g_016_240 = sf_016 & g_60 & right + g_016_270 = sf_016 & g_90 & right + g_016_300 = sf_016 & g_120 & right + g_016_330 = sf_016 & g_150 & right + + # spatial freq 0.04 + g_004 = sf_004 & (g_0 | )& (left | right) + g_004 = g_004_0 | g_004_30 | g_004_60 | g_004_90 | + + # spatial freq 0.16 + """ + + #TODO: natural images + +class Events: + start = 1 << 0 + end = 1 << 1 + +# These are used to convert the trial data into Actions and Events +_side_map = { + Targets.LEFT: "left", + Targets.RIGHT: "right", +} + +_action_map = { + Outcomes.MISSED: "miss", + Outcomes.CORRECT: "correct", + Outcomes.INCORRECT: "incorrect", +} + + + +class PassiveViewing(Behaviour): + #TODO: continue here, see how to map behaviour metadata + def _preprocess_behaviour(self, rec_num, behavioural_data): + # Correction for sessions where sync channel interfered with LED channel + if behavioural_data["/'ReachLEDs'/'0'"].min() < -2: + behavioural_data["/'ReachLEDs'/'0'"] = behavioural_data["/'ReachLEDs'/'0'"] \ + + 0.5 * behavioural_data["/'NpxlSync_Signal'/'0'"] + + behavioural_data = signal.binarise(behavioural_data) + action_labels = np.zeros((len(behavioural_data), 2), dtype=np.uint64) + + try: + cue_leds = behavioural_data["/'ReachLEDs'/'0'"].values + except KeyError: + # some early recordings still used this key + cue_leds = behavioural_data["/'Back_Sensor'/'0'"].values + + led_onsets = np.where((cue_leds[:-1] == 0) & (cue_leds[1:] == 1))[0] + led_offsets = np.where((cue_leds[:-1] == 1) & (cue_leds[1:] == 0))[0] + action_labels[led_onsets, 1] += Events.led_on + action_labels[led_offsets, 1] += Events.led_off + metadata = self.metadata[rec_num] + + # QA: Check that the JSON and TDMS data have the same number of trials + if len(led_onsets) != len(metadata["trials"]): + # If they do not have the same number, perhaps the TDMS was stopped too early + meta_onsets = np.array([t["start"] for t in metadata["trials"]]) * 1000 + meta_onsets = (meta_onsets - meta_onsets[0] + led_onsets[0]).astype(int) + if meta_onsets[-1] > len(cue_leds): + # TDMS stopped too early, continue anyway. + i = -1 + while meta_onsets[i] > len(cue_leds): + metadata["trials"].pop() + i -= 1 + assert len(led_onsets) == len(metadata["trials"]) + else: + # If you have come to debug and see why this error was raised, try: + # led_onsets - meta_onsets[:len(led_onsets)] # This might show the problem + # meta_onsets - led_onsets[:len(meta_onsets)] # This might show the problem + # Then just patch a fix here: + if self.name == "211027_VR49" and rec_num == 1: + del metadata["trials"][52] # Maybe cable fell out of DAQ input? + else: + raise PixelsError( + f"{self.name}: Mantis and Raspberry Pi behavioural " + "data have different no. of trials" + ) + + # QA: Last offset not found in tdms data? + if len(led_offsets) < len(led_onsets): + last_trial = self.metadata[rec_num]['trials'][-1] + if "end" in last_trial: + # Take known offset from metadata + offset = led_onsets[-1] + (last_trial['end'] - last_trial['start']) * 1000 + led_offsets = np.append(led_offsets, int(offset)) + else: + # If not possible, just remove last onset + led_onsets = led_onsets[:-1] + metadata["trials"].pop() + assert len(led_offsets) == len(led_onsets) + + # QA: For some reason, sometimes the final trial metadata doesn't include the + # final led-off even though it is detectable in the TDMS data. + elif len(led_offsets) == len(led_onsets): + # Not sure how to deal with this if led_offsets and led_onsets differ in length + if len(metadata["trials"][-1]) == 1 and "start" in metadata["trials"][-1]: + # Remove it, because we would have to check the video to get all of the + # information about the trial, and it's too complicated. + metadata["trials"].pop() + led_onsets = led_onsets[:-1] + led_offsets = led_offsets[:-1] + + # QA: Check that the cue durations (mostly) match between JSON and TDMS data + # This compares them at 10s of milliseconds resolution + cue_durations_tdms = (led_offsets - led_onsets) / 100 + cue_durations_json = np.array( + [t['end'] - t['start'] for t in metadata['trials']] + ) * 10 + error = sum( + (cue_durations_tdms - cue_durations_json).round() != 0 + ) / len(led_onsets) + if error > 0.05: + raise PixelsError( + f"{self.name}: Mantis and Raspberry Pi behavioural data have mismatching trial data." + ) + + return behavioural_data, action_labels, led_onsets + + def _extract_action_labels(self, rec_num, behavioural_data, plot=False): + behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(rec_num, behavioural_data) + + for i, trial in enumerate(self.metadata[rec_num]["trials"]): + side = _side_map[trial["spout"]] + outcome = trial["outcome"] + if outcome in _action_map: + action = _action_map[trial["outcome"]] + action_labels[led_onsets[i], 0] += getattr(ActionLabels, f"{action}_{side}") + + if plot: + plt.clf() + _, axes = plt.subplots(4, 1, sharex=True, sharey=True) + axes[0].plot(back_sensor_signal) + if "/'Back_Sensor'/'0'" in behavioural_data: + axes[1].plot(behavioural_data["/'Back_Sensor'/'0'"].values) + else: + axes[1].plot(behavioural_data["/'ReachCue_LEDs'/'0'"].values) + axes[2].plot(action_labels[:, 0]) + axes[3].plot(action_labels[:, 1]) + plt.plot(action_labels[:, 1]) + plt.show() + + return action_labels + + def draw_slit_thresholds(self, project: str, force: bool = False): + """ + Draw lines on the slits using EasyROI. If ROIs already exist, skip. + + Parameters + ========== + project : str + The DLC project i.e. name/prefix of the camera. + + force : bool + If true, we will draw new lines even if the output file exists. + + """ + # Only needed for this method + import cv2 + import EasyROI + + output = self.processed / f"slit_thresholds_{project}.pickle" + + if output.exists() and not force: + print(self.name, "- slits drawn already.") + return + + # Let's take the average between the first and last frames of the whole session. + videos = [] + + for recording in self.files: + for v, video in enumerate(recording.get("camera_data", [])): + if project in video.stem: + avi = self.interim / video.with_suffix('.avi') + if not avi.exists(): + meta = recording['camera_meta'][v] + ioutils.tdms_to_video( + self.find_file(video, copy=False), + self.find_file(meta), + avi, + ) + if not avi.exists(): + raise PixelsError(f"Path {avi} should exist but doesn't... discuss.") + videos.append(avi.as_posix()) + + if not videos: + raise PixelsError("No videos were found to draw slits on.") + + first_frame = ioutils.load_video_frame(videos[0], 1) + last_duration = ioutils.get_video_dimensions(videos[-1])[2] + last_frame = ioutils.load_video_frame(videos[-1], last_duration - 1) + + average_frame = np.concatenate( + [first_frame[..., None], last_frame[..., None]], + axis=2, + ).mean(axis=2) + average_frame = np.squeeze(average_frame) / 255 + + # Interactively draw ROI + global _roi_helper + if _roi_helper is None: + # Ugly but we can only have one instance of this + _roi_helper = EasyROI.EasyROI(verbose=False) + lines = _roi_helper.draw_line(average_frame, 2) + cv2.destroyAllWindows() # Needed otherwise EasyROI errors + + # Save a copy of the frame with ROIs to PNG file + png = output.with_suffix(".png") + copy = EasyROI.visualize_line(average_frame, lines, color=(255, 0, 0)) + plt.imsave(png, copy, cmap='gray') + + # Save lines to file + with output.open('wb') as fd: + pickle.dump(lines['roi'], fd) + + def inject_slit_crossings(self): + """ + Take the lines drawn from `draw_slit_thresholds` above, get the reach + coordinates from DLC output, identify the timepoints when successful reaches + crossed the lines, and add the `reach_onset` event to those timepoints in the + action labels. Also identify which trials need clearing up, i.e. those with + multiple reaches or have failed motion tracking, and exclude those. + """ + lines = {} + projects = ("LeftCam", "RightCam") + + for project in projects: + line_file = self.processed / f"slit_thresholds_{project}.pickle" + + if not line_file.exists(): + print(self.name, "- Lines not drawn for session.") + return + + with line_file.open("rb") as f: + proj_lines = pickle.load(f) + + lines[project] = { + tt:[pd.Series(p, index=["x", "y"]) for p in points.values()] + for tt, points in proj_lines.items() + } + + action_labels = self.get_action_labels() + event = Events.led_off + + # https://bryceboe.com/2006/10/23/line-segment-intersection-algorithm + def ccw(A, B, C): + return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x) + + for tt, action in enumerate( + (ActionLabels.correct_left, ActionLabels.correct_right), + ): + data = {} + trajectories = {} + + for project in projects: + proj_data = self.align_trials( + action, + event, + "motion_tracking", + duration=6, + dlc_project=project, + ) + proj_traj, = check_scorers(proj_data) + data[project] = proj_traj + trajectories[project] = get_reach_trajectories(proj_traj)[0] + + for rec_num, recording in enumerate(self.files): + actions = action_labels[rec_num][:, 0] + events = action_labels[rec_num][:, 1] + trial_starts = np.where(np.bitwise_and(actions, action))[0] + + for t, start in enumerate(trial_starts): + centre = np.where(np.bitwise_and(events[start:start + 6000], event))[0] + if len(centre) == 0: + raise PixelsError('Action labels probably miscalculated') + centre = start + centre[0] + # centre is index of this rec's grasp + onsets = [] + for project, motion in trajectories.items(): + left_hand = motion["left_hand_median"][t][:0] + right_hand = motion["right_hand_median"][t][:0] + pt1, pt2 = lines[project][tt] + + x_l = left_hand.iloc[-10:]["x"].mean() + x_r = right_hand.iloc[-10:]["x"].mean() + hand = right_hand + if project == "LeftCam": + if x_l > x_r: + hand = left_hand + else: + if x_r > x_l: + hand = left_hand + + segments = zip( + hand.iloc[::-1].iterrows(), + hand.iloc[-2::-1].iterrows(), + ) + + for (end, ptend), (start, ptsta) in segments: + if ( + ccw(pt1, ptsta, ptend) != ccw(pt2, ptsta, ptend) and + ccw(pt1, pt2, ptsta) != ccw(pt1, pt2, ptend) + ): + # These lines intersect + #print("x from ", ptsta.x, " to ", ptend.x) + break + #if ptend.y > 300 or ptsta.y > 300: + # assert 0 + onsets.append(start) + + onset = max(onsets) + onset_timepoint = round(centre + (onset * 1000)) + events[onset_timepoint] |= Events.reach_onset + + output = self.processed / recording['action_labels'] + np.save(output, action_labels[rec_num]) + + +global _roi_helper +_roi_helper = None + + +class VisualOnly(Reach): + def _extract_action_labels(self, behavioural_data, plot=False): + behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(behavioural_data) + + for i, trial in enumerate(self.metadata["trials"]): + label = "naive_" + _side_map[trial["spout"]] + "_" + if trial["cue_duration"] > 125: + label += "long" + else: + label += "short" + action_labels[led_onsets[i], 0] += getattr(ActionLabels, label) + + return action_labels + + +def get_reach_velocities(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Get the velocity curves for the provided reach trajectories. + """ + results = [] + + for df in dfs: + df = df.copy() + deltas = np.square(df.iloc[1:].values - df.iloc[:-1].values) + # Fill the start with a row of zeros - each value is delta in previous 1 ms + deltas = np.append(np.zeros((1, deltas.shape[1])), deltas, axis=0) + deltas = np.sqrt(deltas[:, ::2] + deltas[:, 1::2]) + df = df.drop([c for c in df.columns if "y" in c], axis=1) + df = df.rename({"x": "delta"}, axis='columns') + df.values[:] = deltas + results.append(df) + + return tuple(results) + + +def get_reach_trajectories(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Get the median centre point of the hand coordinates - i.e. for the labels for each + of the four digits and hand centre for both paws. + """ + assert dfs + bodyparts = get_body_parts(dfs[0]) + right_paw = [p for p in bodyparts if p.startswith("right")] + left_paw = [p for p in bodyparts if p.startswith("left")] + + trajectories_l = [] + trajectories_r = [] + + for df in dfs: + per_ses_l = [] + per_ses_r = [] + + # Ugly hack so this function works on single sessions + if "session" not in df.columns.names: + df = pd.concat([df], keys=[0], axis=1) + sessions = [0] + single_session = True + else: + single_session = False + sessions = df.columns.get_level_values("session").unique() + + for s in sessions: + per_trial_l = [] + per_trial_r = [] + + trials = df[s].columns.get_level_values("trial").unique() + for t in trials: + tdf = df[s][t] + left = pd.concat((tdf[p] for p in left_paw), keys=left_paw) + right = pd.concat((tdf[p] for p in right_paw), keys=right_paw) + per_trial_l.append(left.groupby(level=1).median()) + per_trial_r.append(right.groupby(level=1).median()) + + per_ses_l.append(pd.concat(per_trial_l, axis=1, keys=trials)) + per_ses_r.append(pd.concat(per_trial_r, axis=1, keys=trials)) + + trajectories_l.append(pd.concat(per_ses_l, axis=1, keys=sessions)) + trajectories_r.append(pd.concat(per_ses_r, axis=1, keys=sessions)) + + if single_session: + return tuple( + pd.concat( + [trajectories_l[i][0], trajectories_r[i][0]], + axis=1, + keys=["left_hand_median", "right_hand_median"], + ) + for i in range(len(dfs)) + ) + return tuple( + pd.concat( + [trajectories_l[i], trajectories_r[i]], + axis=1, + keys=["left_hand_median", "right_hand_median"], + ) + for i in range(len(dfs)) + ) + + +def check_scorers(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Checks that the scorers are identical for all data in the dataframes. These are the + dataframes as returned from Exp.align_trials for motion_tracking data. + + It returns the dataframes with the scorer index level removed. + """ + scorers = set( + s + for df in dfs + for s in df.columns.get_level_values("scorer").unique() + ) + + assert len(scorers) == 1, scorers + return tuple(df.droplevel("scorer", axis=1) for df in dfs) + + +def get_body_parts(df: pd.DataFrame) -> list[str]: + """ + Get the list of body part labels. + """ + return df.columns.get_level_values("bodyparts").unique() diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py new file mode 100644 index 0000000..45e8501 --- /dev/null +++ b/pixels/behaviours/virtual_reality.py @@ -0,0 +1,661 @@ +""" +This module provides reach task specific operations. +""" + +# NOTE: for event alignment, we align to the first timepoint when the event +# starts, and the last timepoint before the event ends, i.e., think of an event +# as a train of 0s and 1s, we align to the first 1s and the last 1s of a given +# event, except for licks since it could only be on or off per frame. + +from enum import IntFlag, auto +from typing import NamedTuple + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +from vision_in_darkness.base import Outcomes, Worlds, Conditions + +from pixels import PixelsError +from pixels.behaviours import Behaviour +from pixels.configs import * +from pixels.constants import SAMPLE_RATE, V1_SPIKE_LATENCY + + +class TrialTypes(IntFlag): + """ + These cover all possible trial types. + + To align trials to more than one action type they can be bitwise OR'd i.e. + `miss_light | miss_dark` will match all miss trials. + + Trial types can NOT be added on top of each other, they should be mutually + exclusive. + """ + # TODO jun 7 2024 does the name "action" make sense? + + # TODO jul 4 2024 only label trial type at the first frame of the trial to + # make it easier for alignment??? + # triggered vr trials + NONE = 0 + miss_light = auto()#1 << 0 # 1 + miss_dark = auto()#1 << 1 # 2 + triggered_light = auto()#1 << 2 # 4 + triggered_dark = auto()#1 << 3 # 8 + punished_light = auto()#1 << 4 # 16 + punished_dark = auto()#1 << 5 # 32 + + # given reward + default_light = auto()#1 << 6 # 64 + auto_light = auto()#1 << 7 # 128 + auto_dark = auto()#1 << 8 # 256 + reinf_light = auto()#1 << 9 # 512 + reinf_dark = auto()#1 << 10 # 1024 + + # combos + miss = miss_light | miss_dark + triggered = triggered_light | triggered_dark + punished = punished_light | punished_dark + + # trial type combos + light = miss_light | triggered_light | punished_light | default_light\ + | auto_light | reinf_light + dark = miss_dark | triggered_dark | punished_dark | auto_dark | reinf_dark + rewarded_light = triggered_light | default_light | auto_light | reinf_light + rewarded_dark = triggered_dark | auto_dark | reinf_dark + given_light = default_light | auto_light | reinf_light + given_dark = auto_dark | reinf_dark + # no punishment, i.e., no missing data + completed_light = miss_light | triggered_light | default_light\ + | auto_light | reinf_light + completed_dark = miss_dark | triggered_dark | auto_dark | reinf_dark + + +class Events(IntFlag): + """ + Defines events that could happen during vr sessions. + + Events can be added on top of each other. + """ + NONE = 0 + # vr events + trial_start = auto() # 1 + gray_on = auto() # 2 + gray_off = auto() # 4 + light_on = auto() # 8 + light_off = auto() # 16 + dark_on = auto() # 32 + dark_off = auto() # 64 + punish_on = auto() # 128 + punish_off = auto() # 256 + trial_end = auto() # 512 + + # positional events + pre_dark_end = auto()# 50 cm + + # black wall + landmark0_on = auto()# wherever trial starts, before 60cm + landmark0_off = auto()# 60 cm + + # NOTE: wall between landmarks always remove 10cm adjacent + wall1_on = auto()# 70cm + wall1_off = auto()# 100cm + + landmark1_on = auto()# 110 cm + landmark1_off = auto()# 130 cm + + wall2_on = auto()# 140cm + wall2_off = auto()# 180cm + + landmark2_on = auto()# 190 cm + landmark2_off = auto()# 210 cm + + wall3_on = auto()# 220cm + wall3_off = auto()# 260cm + + landmark3_on = auto()# 270 cm + landmark3_off = auto()# 290 cm + + wall4_on = auto()# 300cm + wall4_off = auto()# 340cm + + landmark4_on = auto()# 350 cm + landmark4_off = auto()# 370 cm + + wall5_on = auto()# 380cm + wall5_off = auto()# 420cm + + landmark5_on = auto()# 430 cm + landmark5_off = auto()# 450 cm + + reward_zone_on = auto()# 460 cm + reward_zone_off = auto()# 495 cm + + # sensors + valve_open = auto() + valve_closed = auto() + licked = auto() + #run_start = auto() + #run_stop = auto() + + # temporal events in dark only + dark_luminance_off = auto()# 500 ms + + +class LabeledEvents(NamedTuple): + """Return type: timestamps + bitfields for outcome & events.""" + timestamps: np.ndarray # shape (N,) + outcome: np.ndarray # shape (N,) dtype uint64 + events: np.ndarray # shape (N,) dtype uint64 + + +class WorldMasks(NamedTuple): + in_gray: pd.Series + in_dark: pd.Series + in_white: pd.Series + in_light: pd.Series + in_tunnel: pd.Series + + +class ConditionMasks(NamedTuple): + light_trials: pd.Series + dark_trials: pd.Series + + +class VR(Behaviour): + """Behaviour subclass to extract action labels & events from vr_data.""" + def _get_world_masks(self, df: pd.DataFrame) -> WorldMasks: + # define in gray + in_gray = (df.world_index == Worlds.GRAY) + # define in dark + in_dark = df.world_index.isin( + {w.value for w in Worlds if w.is_dark} + ) + # define in white + in_white = (df.world_index == Worlds.WHITE) + # define in light + in_light = (df.world_index == Worlds.TUNNEL) + # define in tunnel + in_tunnel = (~in_gray & ~in_white) + + return WorldMasks( + in_gray = in_gray, + in_dark = in_dark, + in_white = in_white, + in_light = in_light, + in_tunnel = in_tunnel, + ) + + + def _get_condition_masks(self, df: pd.DataFrame) -> ConditionMasks: + return ConditionMasks( + light_trials = (df.trial_type == Conditions.LIGHT), + dark_trials = (df.trial_type == Conditions.DARK), + ) + + + def _extract_action_labels( + self, + session, + data: pd.DataFrame + ) -> LabeledEvents: + """ + Go over every frame in data and assign: + - `events_arr[i]` := bitmask of Events that occur at frame i + - `outcomes_arr[i]` := the trial‐outcome (one and only one TrialType) + at i + """ + N = len(data) + events_arr = np.zeros(N, dtype=np.uint64) + outcomes_arr = np.zeros(N, dtype=np.uint64) + + # world index based events + world_index_based_events = self._world_event_indices(data) + for event, idx in world_index_based_events.items(): + mask = self._get_index(data, idx.to_numpy()) + self._stamp_mask(events_arr, mask, np.uint64(event.value)) + + # get dark onset times for pre_dark_len labels + dark_on_t = world_index_based_events[Events.dark_on] + + # positional events (pre‐dark end, landmarks, reward‐zone) + pos_based_events = self._position_event_indices( + session, + data, + dark_on_t, + ) + for event, idx in pos_based_events.items(): + mask = self._get_index(data, idx.to_numpy()) + self._stamp_mask(events_arr, mask, np.uint64(event.value)) + + # sensors: lick rising‐edge + self._stamp_rising( + events_arr, + data.lick_detect.values, + np.uint64(Events.licked), + ) + + # map trial outcomes + outcome_map = self._build_outcome_map() + for trial_id, group in data.groupby("trial_count"): + trial_t = group.index.values + flag, valve_events = self._compute_outcome_flag( + trial_id, + group, + outcome_map, + session, + ) + + # save outcomes + idx = self._get_index(data, trial_t) + outcomes_arr[idx] = flag + + if len(valve_events) > 0: + for v_event, v_idx in valve_events.items(): + valve_mask = self._get_index(data, v_idx) + self._stamp_mask(events_arr, valve_mask, np.uint64(v_event)) + + # return typed arrays + return LabeledEvents( + timestamps = data.index.values, + outcome = outcomes_arr, + events = events_arr, + ) + + + # ------------------------------------------------------------------------- + # Core stamping helpers + # ------------------------------------------------------------------------- + + @staticmethod + def _stamp_mask(array: np.ndarray, mask: np.ndarray, flag: IntFlag): + """ + Bitwise‐OR `flag` into `storage` at every True in `mask`. + """ + np.bitwise_or.at(array, mask, flag) + + @staticmethod + def _stamp_rising(array: np.ndarray, signal: np.ndarray, flag: IntFlag): + """ + Find rising‐edge frames in a 0/1 `signal` array (diff == +1) + and stamp `flag` at those indices. + """ + # extract edges + edges = np.flatnonzero(np.diff(signal, prepend=0) == 1) + np.bitwise_or.at(array, edges, flag) + + + # ------------------------------------------------------------------------- + # Build world‐based event masks + # ------------------------------------------------------------------------- + + def _world_event_indices( + self, + df: pd.DataFrame + ) -> dict[Events, pd.Series]: + """ + Build (event_flag, boolean_mask) for gray, white (punish), + light tunnel, and dark tunnels, but *per trial*, + so each trial contributes its own on/off. + """ + masks: dict[Events, pd.Series] = {} + N = len(df) + + world_masks = self._get_world_masks(df) + + specs = [ + # (which‐world‐test, on‐event, off‐event) + (world_masks.in_gray, Events.gray_on, Events.gray_off), + (world_masks.in_white, Events.punish_on, Events.punish_off), + (world_masks.in_dark, Events.dark_on, Events.dark_off), + ] + + for bool_mask, event_on, event_off in specs: + # compute the per‐trial first‐index + trials = df[bool_mask].groupby("trial_count") + masks[event_on] = trials.apply(self._first_index) + # compute the per‐trial last‐index + masks[event_off] = trials.apply(self._last_index) + + + # >>>> trial ends >>>> + gray_on_t = masks[Events.gray_on] + + # for non punished trials, right before gray on is when trial ends, plus + # the last frame of the session + trial_ends_t = gray_on_t.copy() + trial_ends_t.iloc[:-1] = (gray_on_t[1:] - 1).to_numpy() + trial_ends_t.iloc[-1] = df.index[-1] + + # trial ends right before punishment starts + punish_on_t = masks[Events.punish_on] + trial_ends_t.loc[punish_on_t.index] = punish_on_t - 1 + + masks[Events.trial_end] = trial_ends_t + # <<<< trial ends <<<< + + # >>> handle light on separately >>> + # build a run id that increments whenever in_light flips + run_id = world_masks.in_light.ne( + world_masks.in_light.shift(fill_value=False) + ).cumsum() + + # restrict to just the True‐runs + light_runs = df[world_masks.in_light].copy() + light_runs['run_id'] = run_id[world_masks.in_light] + + # for each (trial, run_id) get the on‐times and off‐times + firsts = ( + light_runs + .groupby(["trial_count", "run_id"]) + .apply(self._first_index) + ) + light_ons = firsts.droplevel("run_id") + lasts = ( + light_runs + .groupby(["trial_count", "run_id"]) + .apply(self._last_index) + ) + light_offs = lasts.droplevel("run_id") + + masks[Events.light_on] = light_ons + masks[Events.light_off] = light_offs + # <<< handle light on separately <<< + + return masks + + + # ------------------------------------------------------------------------- + # Build positional event masks (landmarks, pre_dark_end, reward zone) + # ------------------------------------------------------------------------- + + def _position_event_indices( + self, + session, + df: pd.DataFrame, + dark_on_t, + ) -> dict[Events, pd.Series]: + masks: dict[Events, pd.Series] = {} + + in_tunnel = self._get_world_masks(df).in_tunnel + + def _first_post_mark(group_df, check_marks): + if isinstance(check_marks, pd.Series): + group_id = group_df.name + mark = check_marks.loc[group_id] + elif isinstance(check_marks, int): + mark = check_marks + + # mask and pick the first index + mask = (group_df["position_in_tunnel"] >= mark) + if not mask.any(): + return None + return group_df.index[mask].min() + + # >>> distance travelled before dark onset per trial >>> + # NOTE: this applies to light trials too to keep data symetrical + # AL remove pre_dark_len + 10cm in all his data + in_tunnel_trials = df[in_tunnel].groupby("trial_count") + # get start of each trial + trial_starts = in_tunnel_trials.apply(self._first_index) + masks[Events.trial_start] = trial_starts + + # NOTE: dark trials should in theory have EQUAL index pre_dark_end_t + # and dark_on, BUT! after interpolation, some dark trials have their + # dark onset earlier than expected, those are corrected to the first + # expected position, they will have the same index as pre_dark_end_t. + # others will not since their world_index change later than expected. + # SO! to keep it consistent, for dark trials, pre_dark_end will be the + # SAME frame as dark onsets. + + # get starting positions of all trials + start_pos = in_tunnel_trials["position_in_tunnel"].first() + + # get light trials + lights = self._get_condition_masks(df).light_trials + light_trials = df[in_tunnel & lights].groupby("trial_count") + # get starting positions of light trials plus pre dark length + light_pre_dark_len = light_trials["position_in_tunnel"].first()\ + + session.pre_dark_len + light_pre_dark_end_t = light_trials.apply( + lambda df: _first_post_mark(df, light_pre_dark_len) + ).dropna().astype(int) + + # concat dark and light trials + pre_dark_end_t = pd.concat([dark_on_t, light_pre_dark_end_t]) + masks[Events.pre_dark_end] = pre_dark_end_t + # >>> distance travelled before dark onset per trial >>> + + # >>> end of luminance change after dark onset >>> + n_frames = int(V1_SPIKE_LATENCY * SAMPLE_RATE / 1000) + dark_luminance_off_t = dark_on_t + n_frames + masks[Events.dark_luminance_off] = dark_luminance_off_t + # <<< end of luminance change after dark onset <<< + + # >>> landmark 0 black wall >>> + black_off = session.landmarks[0] + + starts_before_black = (start_pos < black_off) + early_ids = start_pos[starts_before_black].index + early_trials = df[ + in_tunnel & df.trial_count.isin(early_ids) + ].groupby("trial_count") + + # first frame of black wall + landmark0_on = early_trials.apply( + self._first_index + ) + masks[Events.landmark0_on] = landmark0_on + + # last frame of black wall + landmark0_off = early_trials.apply( + lambda df: _first_post_mark(df, black_off) + ) + masks[Events.landmark0_off] = landmark0_off + # <<< landmark 0 black wall <<< + + # >>> landmarks and wall 1 to 5 >>> + landmarks = session.landmarks[1:] + # get walls, excluding adjacent 10cm to landmark + walls = session.mid_walls + + for l, landmark in enumerate(landmarks): + if l % 2 != 0: + continue + + landmark_idx = l // 2 + 1 + + # NOTE: both landmark and wall boolean use < not <= because + # in our vr data we only get the exact value (e.g., 140.0) at + # starting position. + # even idx on, odd idx off + on_landmark = landmark + on_landmarks = ( + (df.position_in_tunnel >= on_landmark) & + (df.position_in_tunnel < on_landmark + 1) + ) + landmark_on = df[on_landmarks].groupby("trial_count").apply( + self._first_index + ) + + off_landmark = landmarks[l + 1] + off_landmarks = ( + (df.position_in_tunnel > off_landmark - 1) & + (df.position_in_tunnel < off_landmark) + ) + landmark_off = df[off_landmarks].groupby("trial_count").apply( + self._last_index + ) + masks[getattr(Events, f"landmark{landmark_idx}_on")] = landmark_on + masks[getattr(Events, f"landmark{landmark_idx}_off")] = landmark_off + + # even idx on, odd idx off + on_wall = walls[l] + on_walls = ( + (df.position_in_tunnel >= on_wall) & + (df.position_in_tunnel < on_wall + 1) + ) + wall_on = df[on_walls].groupby("trial_count").apply( + self._first_index + ) + + off_wall = walls[l + 1] + off_walls = ( + (df.position_in_tunnel > off_wall - 1) & + (df.position_in_tunnel < off_wall) + ) + wall_off = df[off_walls].groupby("trial_count").apply( + self._last_index + ) + + masks[getattr(Events, f"wall{landmark_idx}_on")] = wall_on + masks[getattr(Events, f"wall{landmark_idx}_off")] = wall_off + # <<< landmarks and wall 1 to 5 <<< + + # >>> reward zone >>> + zone_ons = ( + df.position_in_tunnel >= session.reward_zone_start + ) & ( + df.position_in_tunnel < session.reward_zone_start + 1 + ) + # first frame in reward zone + zone_on_t = df[zone_ons].groupby("trial_count").apply( + self._first_index + ) + + zone_offs = ( + df.position_in_tunnel > session.reward_zone_end - 1 + ) & ( + df.position_in_tunnel < session.reward_zone_end + ) + # last frame in reward zone + zone_off_t = df[zone_offs].groupby("trial_count").apply( + self._last_index + ) + + masks[Events.reward_zone_on] = zone_on_t + masks[Events.reward_zone_off] = zone_off_t + # <<< reward zone <<< + + return masks + + + # ------------------------------------------------------------------------- + # Run‐start / run‐end utilities for boolean masks + # ------------------------------------------------------------------------- + + def _first_index(self, group: pd.DataFrame) -> int: + idx = group.index + early_idx = idx[:len(idx) // 2] + + # double check the last index is the first time reaching that point + idx_discontinued = (np.diff(early_idx) > 1) + #if np.any(idx_discontinued): + # print("discontinued") + # assert 0 + # last_disc = np.where(idx_discontinued)[0][-1] + # return group.iloc[last_disc:].index.min() + #else: + return group.index.min() + + + def _last_index(self, group: pd.DataFrame) -> int: + # only check the second half in case the discontinuity happened at the + # beginning + idx = group.index + late_idx = idx[-len(idx)//4:] + + # double check the last index is the first time reaching that point + idx_discontinued = (np.diff(late_idx) > 1) + if np.any(idx_discontinued): + logging.warning("\n> index discontinued.") + print(group) + first_disc = np.where(idx_discontinued)[0][0] + return late_idx[first_disc] + else: + return idx.max() + + + def _get_index(self, df: pd.DataFrame, index) -> int: + return df.index.get_indexer(index) + + # ------------------------------------------------------------------------- + # Outcome mapping + # ------------------------------------------------------------------------- + + def _build_outcome_map(self) -> dict: + m = { + (Outcomes.ABORTED_LIGHT, Conditions.LIGHT): TrialTypes.miss_light, + (Outcomes.ABORTED_DARK, Conditions.DARK): TrialTypes.miss_dark, + + (Outcomes.NONE, Conditions.LIGHT): TrialTypes.punished_light, + (Outcomes.NONE, Conditions.DARK): TrialTypes.punished_dark, + + (Outcomes.DEFAULT, Conditions.LIGHT): TrialTypes.default_light, + + (Outcomes.AUTO_LIGHT, Conditions.LIGHT): TrialTypes.auto_light, + (Outcomes.AUTO_DARK, Conditions.DARK): TrialTypes.auto_dark, + + (Outcomes.REINF_LIGHT, Conditions.LIGHT): TrialTypes.reinf_light, + (Outcomes.REINF_DARK, Conditions.DARK): TrialTypes.reinf_dark, + + (Outcomes.TRIGGERED, Conditions.LIGHT): TrialTypes.triggered_light, + (Outcomes.TRIGGERED, Conditions.DARK): TrialTypes.triggered_dark, + } + return m + + def _compute_outcome_flag( + self, + trial_id: int, + trial_df: pd.DataFrame, + outcome_map: dict, + session, + ) -> TrialTypes: + + valve_events = {} + + # get non-zero reward types + reward_not_none = (trial_df.reward_type != Outcomes.NONE) + reward_typed = trial_df[reward_not_none] + + # get trial type + trial_type = int(trial_df.trial_type.iloc[0]) + + # get punished + punished = (trial_df.world_index == Worlds.WHITE) + + if (reward_typed.size == 0) & (not np.any(punished)): + # >>>> unfinished trial >>>> + # double check it is the last trial + assert (trial_df.position_in_tunnel.max()\ + < session.tunnel_reset) + logging.info(f"\n> trial {trial_id} is unfinished when session " + "ends, so there is no outcome.") + return TrialTypes.NONE, valve_events + # <<<< unfinished trial <<<< + elif (reward_typed.size == 0) & np.any(punished): + logging.info(f"\n> trial {trial_id} is punished.") + # get reward type zero for punished + reward_type = int(trial_df.reward_type.unique()) + else: + # get non-zero reward type in current trial + reward_type = int(reward_typed.reward_type.unique()) + + if reward_type > Outcomes.NONE: + # >>>> non aborted, valve events >>>> + # if not aborted, map valve open & closed + # map valve open + valve_open_t = reward_typed.index[0] + valve_events[Events.valve_open] = [valve_open_t] + # map valve closed + valve_closed_t = reward_typed.index[-1] + valve_events[Events.valve_closed] = [valve_closed_t] + # <<<< non aborted, valve events <<<< + + # build key for outcome_map + key = (reward_type, trial_type) + + try: + return outcome_map[key], valve_events + except KeyError: + raise PixelsError(f"No mapping for outcome {key}") diff --git a/pixels/configs.py b/pixels/configs.py new file mode 100644 index 0000000..934019d --- /dev/null +++ b/pixels/configs.py @@ -0,0 +1,49 @@ +import logging + +from wavpack_numcodecs import WavPack +from numcodecs import Blosc +import spikeinterface as si + +# Configure logging to include a timestamp with seconds +logging.basicConfig( + level=logging.INFO, + format='''\n%(asctime)s %(levelname)s: %(message)s\ + \n[in %(filename)s:%(lineno)d]''', + datefmt='%Y%m%d %H:%M:%S', +) + +#logging.info('This is an info message.') +#logging.warning('This is a warning message.') +#logging.error('This is an error message.') + +# set si job_kwargs +job_kwargs = dict( + pool_engine="thread", # instead of default "process" + #pool_engine="process",# does not work on 2025 oct 14 + mp_context="fork", # linux + #mp_context="spawn", # mac & win + progress_bar=True, + n_jobs=0.8, + chunk_duration='1s', + max_threads_per_worker=8, +) +si.set_global_job_kwargs(**job_kwargs) + +# instantiate WavPack compressor +wv_compressor = WavPack( + level=3, # high compression + bps=None, # lossless +) + +# use blosc compressor for generic zarr +compressor = Blosc( + cname="zstd", + clevel=5, + shuffle=Blosc.BITSHUFFLE, +) + +# kilosort 4 singularity image names +ks4_0_30_image_name = "si103.0_ks4-0-30_with_wavpack.sif" +#ks4_0_30_image_name = "si102.3_ks4-0-30_with_wavpack.sif" +ks4_0_18_image_name = "ks4-0-18_with_wavpack.sif" +ks4_image_name = ks4_0_30_image_name diff --git a/pixels/constants.py b/pixels/constants.py new file mode 100644 index 0000000..004b2b6 --- /dev/null +++ b/pixels/constants.py @@ -0,0 +1,36 @@ +""" +This file contains some constants parameters for the pixels pipeline. +""" +import numpy as np + +SAMPLE_RATE = 2000 # Hz + +freq_bands = { + "ap":[300, 9000], + "lfp":[0.5, 500], + "theta":[4, 11], # from Tom + "gamma":[30, 80], # from Tom + "ripple":[110, 220], # from Tom +} + +BEHAVIOUR_HZ = 25000 + +np.random.seed(BEHAVIOUR_HZ) + +REPEATS = 100 + +# latency of luminance change evoked spike +# NOTE: usually it should be between 40-60ms. now we use 500ms just to test +V1_SPIKE_LATENCY = 500 # ms #60 +V1_LFP_LATENCY = 40 # ms +LGN_SPIKE_LATENCY = 30 # ms + +# chunking for zarr +SMALL_CHUNKS = 64 +BIG_CHUNKS = 1024 + +ALPHA = 0.05 + +# position bin sizes +POSITION_BIN = 1 # cm +BIG_POSITION_BIN = 10 # cm diff --git a/pixels/decorators.py b/pixels/decorators.py new file mode 100644 index 0000000..63cd83a --- /dev/null +++ b/pixels/decorators.py @@ -0,0 +1,712 @@ +# annotations not evaluated at runtime +from __future__ import annotations + +import shutil +from pathlib import Path +from functools import wraps +from typing import Any +import inspect + +import numpy as np +import pandas as pd +from tables import HDF5ExtError +try: + import zarr + from numcodecs import Blosc +except Exception: + zarr = None + Blosc = None + +try: + import xarray as xr +except Exception: + xr = None + +from pixels.configs import * +from pixels.constants import * +from pixels import ioutils +from pixels.error import PixelsError +from pixels.units import SelectedUnits + + +def _safe_key(s: str) -> str: + return str(s).replace("/", "_").replace(".", "_") + + +def _make_default_compressor() -> Any: + if Blosc is None: + return None + return Blosc(cname="zstd", clevel=5, shuffle=Blosc.BITSHUFFLE) + +# --------------------------- +# xarray <-> DataFrame helpers +# --------------------------- + +def _default_names(names: list[str | None], prefix: str) -> list[str]: + # Replace None level names with defaults: f"{prefix}{i}" + return [n if n is not None else f"{prefix}{i}" for i, n in enumerate(names)] + +def _df_to_zarr_via_xarray( + df: pd.DataFrame, + *, + path: Path | None = None, + store: "zarr.storage.Store" | None = None, + group_name: str | None = None, + compressor=None, + mode: str = "w", +) -> None: + """ + Write DataFrame (supports MultiIndex) to a Zarr store/group via xarray. + If df.index is MultiIndex, we reset it into coordinate variables and record attrs + so we can reconstruct on read. + + Provide either path (DirectoryStore path) or (store, group). + """ + if xr is None or zarr is None: + raise ImportError( + "xarray/zarr not installed. pip install xarray zarr numcodecs" + ) + + row_prefix = "row" + col_prefix = "col" + + # Ensure all index/column level names are defined + if isinstance(df.index, pd.MultiIndex): + row_names = _default_names(list(df.index.names), row_prefix) + else: + if not df.index.name: + df.index.name = f"{row_prefix}0" + row_names = [df.index.name] + + if isinstance(df.columns, pd.MultiIndex): + col_names = _default_names(list(df.columns.names), col_prefix) + else: + if not df.columns.name: + df.columns.name = f"{col_prefix}0" + col_names = [df.columns.name] + + # stack ALL column levels to move them into the row index; result index + # levels = row_names + col_names + # Series with MultiIndex index + series = df.stack(col_names, future_stack=True) + + # Build DataArray (dims are level names of the Series index, in order) + da = xr.DataArray.from_series(series).rename("values") + ds = da.to_dataset() + + # Mark attrs for round-trip (which dims belong to rows vs columns) + ds.attrs["__via"] = "pd_df_any_mi" + ds.attrs["__row_dims__"] = row_names + ds.attrs["__col_dims__"] = col_names + + # check size to determine chunking + chunking = {} + for name, size in ds.sizes.items(): + if size > BIG_CHUNKS: + chunking[name] = BIG_CHUNKS + else: + chunking[name] = SMALL_CHUNKS + + ds = ds.chunk(chunking) + + if compressor is None: + compressor = _make_default_compressor() + + # compressor & object codec + encoding = { + "values": {"compressor": compressor} + } + # ensure coords are writable (handle object/string coords): cast to str + for cname, coord in ds.coords.items(): + if coord.dtype == object: + ds = ds.assign_coords({cname: coord.astype(str)}) + + # Write + if path is not None: + ds.to_zarr( + str(path), + mode=mode, + encoding=encoding, + ) + try: + zarr.consolidate_metadata(str(path)) + except Exception: + pass + else: + assert store is not None + ds.to_zarr( + store=store, + group=group_name or "", + mode=mode, + encoding=encoding, + ) + + +def _df_from_zarr_via_xarray( + *, + path: Path | None = None, + store: "zarr.storage.Store" | None = None, + group_name: str | None = None, +) -> pd.DataFrame: + """ + Read a DataFrame written by _df_to_zarr_via_xarray and reconstruct + MultiIndex if attrs exist. + Provide either path or (store, group). + """ + if xr is None or zarr is None: + raise ImportError( + "xarray/zarr not installed. pip install xarray zarr numcodecs" + ) + + if path is not None: + ds = xr.open_zarr( + str(path), + consolidated=True, + chunks="auto", + ) + else: + ds = xr.open_zarr( + store=store, + group=group_name or "", + consolidated=False, + chunks="auto", + ) + + da = ds["values"] + row_dim = list(ds.attrs.get("__row_dims__") or []) + col_dim = list(ds.attrs.get("__col_dims__") or []) + + # Series with MultiIndex index (row_dim, *col_dim) + series = da.to_series() + + # If there are column dims, unstack them back to columns + if col_dim: + df = series.unstack(col_dim) + else: + # No column dims -> a single column DataFrame + df = series.to_frame(name="values") + + col_name = [df.columns.name] + if not (row_dim == df.index.names): + df.index.set_names(row_dim, inplace=True) + if not (col_dim == col_name): + if isinstance(df.columns, pd.MultiIndex): + df.columns.set_names(col_dim, inplace=True) + else: + df.columns.name = col_dim[0] + + return df + + +# ----------------------------------- +# Zarr read/write for arrays and dicts +# ----------------------------------- + +def _normalise_1d_chunks(chunks: Any, n: int) -> Any: + if chunks is None: + return None + if isinstance(chunks, int): + return (min(chunks, n),) + if isinstance(chunks, (tuple, list)): + if len(chunks) == 0: + return None + return (min(int(chunks[0]), n),) + return None + + +def _write_arrays_dicts_to_zarr( + root_path: Path, + obj: Any, + *, + chunks: Any = None, + compressor: Any = None, + overwrite: bool = False, +) -> None: + """ + Write ndarray or dict/nested-dict of ndarrays/DataFrames into a Zarr + directory. + DataFrames inside dicts are written via xarray into corresponding groups. + Top-level pure DataFrame should be written with _df_to_zarr_via_xarray + instead. + """ + if zarr is None: + raise ImportError( + "zarr/numcodecs not installed. pip install zarr numcodecs" + ) + + store = zarr.DirectoryStore(str(root_path)) + if overwrite and root_path.exists(): + shutil.rmtree(root_path) + root = zarr.group( + store=store, + overwrite=overwrite or (not root_path.exists()) + ) + + def write_into(prefix: str, value: Any): + # prefix: group path relative to root ("" for root) + if isinstance(value, np.ndarray): + g = zarr.open_group(store=store, path=prefix, mode="a") + name = "array" if prefix == "" else prefix.split("/")[-1] + # In groups, datasets live as siblings; for arrays we use the + # current group's name + # Better: use a fixed name for standalone arrays in a group + # Here we store as "values" for groups, or "array" at root if + # top-level ndarray + ds_name = "array" if prefix == "" else "values" + if ds_name in g: + del g[ds_name] + ds_chunks = chunks + if isinstance(chunks, int) and value.ndim == 1: + ds_chunks = _normalise_1d_chunks(chunks, len(value)) + g.create_dataset( + name=ds_name, + data=value, + chunks=ds_chunks, + compressor=compressor, + ) + + elif isinstance(value, pd.DataFrame): + # Write DF via xarray under this group path + _df_to_zarr_via_xarray( + value, + store=store, + group_name=prefix or "", + compressor=compressor, + mode="w", + ) + + elif isinstance(value, dict): + # Recurse for each item + for k, v in value.items(): + key = _safe_key(k) + next_prefix = f"{prefix}/{key}" if prefix else key + # Ensure group exists + zarr.open_group(store=store, path=prefix, mode="a") + write_into(next_prefix, v) + + else: + raise TypeError( + "Zarr backend supports ndarray, DataFrame, or dicts of them. " + f"Got: {type(value)} at group '{prefix or '/'}'" + ) + + if isinstance(obj, dict): + for k, v in obj.items(): + write_into(_safe_key(k), v) + elif isinstance(obj, np.ndarray): + write_into("", obj) + else: + raise TypeError( + "Top-level object must be ndarray or dict for this writer." + ) + + # Optional consolidation (best when writing via path) + try: + zarr.consolidate_metadata(str(root_path)) + except Exception: + pass + + +def _read_zarr_generic(root_path: Path) -> Any: + """ + Read back what _write_arrays_dicts_to_zarr wrote and also detect DataFrame + groups written via xarray (by checking group attrs['__via']). + + Returns: + - DataFrame (if top-level was DF written via xarray) + - zarr.Array (if top-level ndarray) + - dict tree mixing DataFrames and zarr.Arrays + """ + if zarr is None: + raise ImportError( + "zarr/numcodecs not installed. pip install zarr numcodecs" + ) + + store = zarr.DirectoryStore(str(root_path)) + if not root_path.exists(): + return None + root = zarr.open_group(store=store, mode="r") + + # If top-level was written via xarray as a DataFrame + if root.attrs.get("__via") == "pd_df_any_mi" and xr is not None: + return _df_from_zarr_via_xarray(store=store, group_name="") + + # If top-level is a single array written at root + if "array" in root and isinstance(root["array"], zarr.Array): + return root["array"] + + def read_from_group(prefix: str) -> Any: + g = zarr.open_group(store=store, path=prefix, mode="r") + # DataFrame group? + if g.attrs.get("__via") == "pd_df_any_mi" and xr is not None: + return _df_from_zarr_via_xarray( + store=store, + group_name=prefix or "", + ) + + out: dict[str, Any] = {} + for name, node in g.items(): + full = f"{prefix}/{name}" if prefix else name + if isinstance(node, zarr.Array): + # Arrays inside groups are stored as "values" + if name == "values" and prefix: + out[prefix.split("/")[-1]] = node + else: + out[name] = node + elif isinstance(node, zarr.hierarchy.Group): + # Recurse + res = read_from_group(full) + # If res is a DF and 'name' was only a container, store under + # that key + out[name] = res + return out + + return read_from_group("") + + +def _filter_reserved_kwargs(fn, reserved: dict[str, Any]) -> dict[str, Any]: + """ + Return subset of `reserved` that `fn` will accept (has a parameter by that + name, or has **kwargs). + """ + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return {} + # if method has **kwargs, pass everything + if any( + p.kind == inspect.Parameter.VAR_KEYWORD + for p in sig.parameters.values() + ): + return reserved + + # otherwise pass only the names it declares + return {k: v for k, v in reserved.items() if k in sig.parameters} + +# ----------------------- +# Decorator with Zarr +# ----------------------- + +def cacheable( + _func=None, + *, + cache_format: str | None = None, + zarr_chunks: Any = None, + zarr_compressor: Any = None, + zarr_dim_name: str | None = None, +): + """ + Decorator factory for caching. + + Usage: + @cacheable # default HDF5 + def f(...): ... + + @cacheable(cache_format='zarr') # Zarr default for this method + def g(...): ... + + Backend precedence: + per-call kwarg cache_format > per-method (decorator args) + > instance default self._cache_format > 'hdf5' + + Zarr options: + - zarr_chunks: int (rows per chunk for DataFrame; 1D arrays) or tuple/dict for arrays/xarray + - zarr_compressor: a numcodecs compressor (e.g., Blosc(...)); default zstd+bitshuffle + - zarr_dim_name: optional row-dimension name when writing DataFrame via xarray + """ + default_backend = cache_format or "hdf5" + + def decorator(method): + @wraps(method) + def wrapper(*args, **kwargs): + name = kwargs.pop("name", None) + + # Per-call overrides + per_call_backend = kwargs.pop("cache_format", None) + per_call_chunks = kwargs.pop("zarr_chunks", zarr_chunks) + per_call_compressor = kwargs.pop("zarr_compressor", zarr_compressor) + per_call_dim_name = kwargs.pop("zarr_dim_name", zarr_dim_name) + + inst = args[0] + + # Units gating (unchanged) + if "units" in kwargs: + units = kwargs["units"] + if not isinstance(units, SelectedUnits) or not hasattr(units, "name"): + return method(*args, **kwargs) + + if not getattr(inst, "_use_cache", True): + return method(*args, **kwargs) + + # Build key parts (unchanged) + self_, *as_list = list(args) + list(kwargs.values()) + arrays = [i for i, arg in enumerate(as_list) if isinstance(arg, np.ndarray)] + if arrays: + if name is None: + raise PixelsError( + "Cacheing methods when passing arrays requires also " + "passing name='something'" + ) + for i in arrays: + as_list[i] = name + + key_parts = [method.__name__] + [ + str(i.name) if hasattr(i, "name") else str(i) for i in as_list + ] + base = inst.cache / ("_".join(key_parts) + f"_{inst.stream_id}") + + backend = per_call_backend\ + or getattr(inst, "_cache_format", None)\ + or default_backend + + # HDF5 backend + if backend == "hdf5": + cache_path = base.with_name(base.name + ".h5") + if cache_path.exists() and inst._use_cache != "overwrite": + try: + df = ioutils.read_hdf5(cache_path) + logging.info(f"\n> Cache loaded from {cache_path}.") + except HDF5ExtError: + df = None + logging.info("\n> df is None, cache does not exist.") + except (KeyError, ValueError): + df = {} + with pd.HDFStore(cache_path, "r") as store: + for key in store.keys(): + parts = key.lstrip("/").split("/") + if len(parts) == 1: + df[parts[0]] = store[key] + elif len(parts) == 2: + stream, nm = parts[0], "/".join(parts[1:]) + df.setdefault(stream, {})[nm] = store[key] + logging.info(f"\n> Cache loaded from {cache_path}.") + else: + df = method(*args, **kwargs) + cache_path.parent.mkdir(parents=True, exist_ok=True) + if df is None: + cache_path.touch() + logging.info("\n> df is None, cache will exist but be empty.") + else: + if isinstance(df, dict): + if ioutils.is_nested_dict(df): + for probe_id, nested_dict in df.items(): + for nm, values in nested_dict.items(): + ioutils.write_hdf5( + path=cache_path, + df=values, + key=f"/{probe_id}/{nm}", + mode="a", + ) + else: + for nm, values in df.items(): + ioutils.write_hdf5( + path=cache_path, df=values, key=nm, mode="a" + ) + else: + ioutils.write_hdf5(cache_path, df) + return df + + # Zarr backend (with DataFrame via xarray, MultiIndex supported) + if backend == "zarr": + if zarr is None: + raise ImportError( + "cache_format='zarr' requires zarr. pip install zarr numcodecs xarray" + ) + zarr_path = base.with_name(base.name + ".zarr") + can_read = zarr_path.exists() and inst._use_cache != "overwrite" + + if can_read: + try: + obj = _read_zarr_generic(zarr_path) + logging.info(f"\n> Zarr cache loaded from {zarr_path}.") + return obj + except Exception as e: + logging.info(f"\n> Failed to read Zarr cache ({e}); recomputing.") + + # inject reserved kwargs so the method can write directly to + # store, if the method accepts + reserved = {"_zarr_out": zarr_path} + kwargs.update(_filter_reserved_kwargs(method, reserved)) + + # Compute fresh + result = method(*args, **kwargs) + if result is None: + # Method handled writing itself; read and return + obj = _read_zarr_generic(zarr_path) + logging.info(f"\n> Zarr cache written to {zarr_path}.") + return obj + + # Overwrite + if inst._use_cache == "overwrite" and zarr_path.exists(): + shutil.rmtree(zarr_path) + + compressor = per_call_compressor or _make_default_compressor() + + # DataFrame via xarray (works for MultiIndex) + if isinstance(result, pd.DataFrame): + _df_to_zarr_via_xarray( + result, + path=zarr_path, + compressor=compressor, + mode="w", + ) + logging.info( + f"\n> Zarr cache (DataFrame via xarray) written to {zarr_path}." + ) + return result + + # Dict/nested-dict of arrays or DataFrames + if isinstance(result, dict) or isinstance(result, np.ndarray): + _write_arrays_dicts_to_zarr( + zarr_path, + result, + chunks=per_call_chunks, + compressor=compressor, + overwrite=True, + ) + logging.info(f"\n> Zarr cache written to {zarr_path}.") + return result + + # Fallback for unsupported types: write HDF5 like before + logging.warning( + "cache_format='zarr' requested but result type " + "not supported for Zarr; falling back to HDF5." + ) + h5_fallback = base.with_suffix(".h5") + if isinstance(result, dict): + if ioutils.is_nested_dict(result): + for probe_id, nested_dict in result.items(): + for nm, values in nested_dict.items(): + ioutils.write_hdf5( + path=h5_fallback, + df=values, + key=f"/{probe_id}/{nm}", + mode="a", + ) + else: + for nm, values in result.items(): + ioutils.write_hdf5( + path=h5_fallback, + df=values, + key=str(nm), + mode="a", + ) + else: + ioutils.write_hdf5(h5_fallback, result) + logging.info(f"\n> Cache written to {h5_fallback} (fallback).") + return result + + raise ValueError(f"Unknown cache_format/backend: {backend}") + + return wrapper + + if _func is None: + return decorator + else: + return decorator(_func) + +''' +import numpy as np +import pandas as pd +from tables import HDF5ExtError + +from pixels.configs import * +from pixels import ioutils +from pixels.error import PixelsError +from pixels.units import SelectedUnits + +def cacheable(method): + """ + Methods with this decorator will have their output cached to disk so that + future calls with the same set of arguments will simply load the result from + disk. However, from pixels.error import PixelsError if the key word argument + list contains `units` and it is not either `None` or an instance of + `SelectedUnits` then this is disabled. + """ + def wrapper(*args, **kwargs): + name = kwargs.pop("name", None) + + if "units" in kwargs: + units = kwargs["units"] + if not isinstance(units, SelectedUnits) or not hasattr(units, "name"): + return method(*args, **kwargs) + + self, *as_list = list(args) + list(kwargs.values()) + if not self._use_cache: + return method(*args, **kwargs) + + arrays = [i for i, arg in enumerate(as_list) if isinstance(arg, np.ndarray)] + if arrays: + if name is None: + raise PixelsError( + "Cacheing methods when passing arrays requires also " + "passing name='something'" + ) + for i in arrays: + as_list[i] = name + + # build a key: method name + all args + key_parts = [method.__name__] + [str(i.name) if hasattr(i, "name") + else str(i) for i in as_list] + cache_path = self.cache /\ + ("_".join(key_parts) + f"_{self.stream_id}.h5") + + if cache_path.exists() and self._use_cache != "overwrite": + # load cache + try: + df = ioutils.read_hdf5(cache_path) + logging.info(f"\n> Cache loaded from {cache_path}.") + except HDF5ExtError: + df = None + logging.info("\n> df is None, cache does not exist.") + except (KeyError, ValueError): + # if key="df" is not found, then use HDFStore to list and read + # all dfs + # create df as a dictionary to hold all dfs + df = {} + with pd.HDFStore(cache_path, "r") as store: + # list all keys + for key in store.keys(): + # remove "/" in key and split + parts = key.lstrip("/").split("/") + if len(parts) == 1: + # use the only key name as dict key + df[parts[0]] = store[key] + elif len(parts) == 2: + # stream id is the first, data name is the second + stream, name = parts[0], "/".join(parts[1:]) + df.setdefault(stream, {})[name] = store[key] + logging.info(f"\n> Cache loaded from {cache_path}.") + else: + df = method(*args, **kwargs) + cache_path.parent.mkdir(parents=True, exist_ok=True) + if df is None: + cache_path.touch() + logging.info("\n> df is None, cache will exist but be empty.") + else: + # allows to save multiple dfs in a dict in one hdf5 file + if isinstance(df, dict): + if ioutils.is_nested_dict(df): + for probe_id, nested_dict in df.items(): + # NOTE: we remove `.ap` in stream id cuz having `.`in + # the key name get problems + for name, values in nested_dict.items(): + ioutils.write_hdf5( + path=cache_path, + df=values, + key=f"/{probe_id}/{name}", + mode="a", + ) + else: + for name, values in df.items(): + ioutils.write_hdf5( + path=cache_path, + df=values, + key=name, + mode="a", + ) + else: + ioutils.write_hdf5(cache_path, df) + return df + return wrapper +''' diff --git a/pixels/experiment.py b/pixels/experiment.py index 7346c85..f7f910b 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -12,6 +12,7 @@ from pixels import ioutils from pixels.error import PixelsError +from pixels.configs import * class Experiment: @@ -50,7 +51,10 @@ def __init__( data_dir, meta_dir=None, interim_dir=None, + hist_dir=None, + processed_dir=None, session_date_fmt="%y%m%d", + of_date=None, ): if not isinstance(mouse_ids, (list, tuple, set)): mouse_ids = [mouse_ids] @@ -69,17 +73,29 @@ def __init__( else: self.meta_dir = None + if hist_dir: + self.hist_dir = Path(hist_dir).expanduser() + self.sessions = [] - sessions = ioutils.get_sessions(mouse_ids, self.data_dir, self.meta_dir, session_date_fmt) + sessions = ioutils.get_sessions( + mouse_ids, + self.data_dir, + self.meta_dir, + session_date_fmt, + of_date, + ) for name, metadata in sessions.items(): - assert len(set(s['data_dir'] for s in metadata)) == 1, "All JSON items with same day must use same data folder." + assert len(set(s['data_dir'] for s in metadata)) == 1,\ + "All JSON items with same day must use same data folder." self.sessions.append( behaviour( name, metadata=[s['metadata'] for s in metadata], data_dir=metadata[0]['data_dir'], interim_dir=interim_dir, + hist_dir=hist_dir, + processed_dir=processed_dir, ) ) @@ -107,21 +123,23 @@ def set_cache(self, on): for session in self.sessions: session.set_cache(on) - def process_spikes(self): + def extract_ap(self): """ - Process the spike data from the raw neural recording data for all sessions. + Process the ap band from the raw neural recording data for all sessions. """ for i, session in enumerate(self.sessions): - print(">>>>> Processing spikes for session {} ({} / {})" + print(">>>>> Processing ap band for session {} ({} / {})" .format(session.name, i + 1, len(self.sessions))) - session.process_spikes() + session.extract_ap() - def sort_spikes(self): + def sort_spikes(self, mc_method="dredge"): """ Extract the spikes from raw spike data for all sessions. """ for i, session in enumerate(self.sessions): - print(">>>>> Sorting spikes for session {} ({} / {})" - .format(session.name, i + 1, len(self.sessions))) - session.sort_spikes() + logging.info( + "\n>>>>> Sorting spikes for session " + f"{session.name} ({i + 1} / {len(self.sessions)})" + ) + session.sort_spikes(mc_method=mc_method) def assess_noise(self): """ @@ -132,14 +150,15 @@ def assess_noise(self): .format(session.name, i + 1, len(self.sessions))) session.assess_noise() - def process_lfp(self): + def extract_bands(self): """ - Process the LFP data from the raw neural recording data for all sessions. + Extract ap & lfp data from the raw neural recording data for all + sessions. """ for i, session in enumerate(self.sessions): - print(">>>>> Processing LFP data for session {} ({} / {})" + print(">>>>> Extracting ap & lfp data for session {} ({} / {})" .format(session.name, i + 1, len(self.sessions))) - session.process_lfp() + session.extract_bands() def process_behaviour(self): """ @@ -205,10 +224,18 @@ def select_units(self, *args, **kwargs): Select units based on specified criteria. The output of this can be passed to some other methods to apply those methods only to these units. """ - units = [] + units = {} for i, session in enumerate(self.sessions): - units.append(session.select_units(*args, **kwargs)) + name = session.name + selected = session.select_units(*args, **kwargs) + if len(selected) == 0: + logging.warning( + f"\n> {name} does not have units in {selected}, " + "skip." + ) + else: + units[name] = selected return units @@ -219,14 +246,19 @@ def align_trials(self, *args, units=None, **kwargs): """ trials = {} for i, session in enumerate(self.sessions): + name = session.name result = None if units: - if units[i]: - result = session.align_trials(*args, units=units[i], **kwargs) + if units[name]: + result = session.align_trials( + *args, + units=units[name], + **kwargs, + ) else: result = session.align_trials(*args, **kwargs) if result is not None: - trials[i] = result + trials[name] = result if "motion_tracking" in args: df = pd.concat( @@ -235,6 +267,28 @@ def align_trials(self, *args, units=None, **kwargs): names=["session", "trial", "scorer", "bodyparts", "coords"] ) + if "trial_rate" in kwargs.values(): + level_names = ["session", "stream", "unit", "trial"] + fr = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="fr", + level_names=level_names, + ) + spiked = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="spiked", + level_names=level_names, + ) + positions = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="positions", + level_names=["session", "stream", "start", "trial"], + ) + df = { + "fr": fr, + "spiked": spiked, + "positions": positions, + } else: df = pd.concat( trials.values(), axis=1, copy=False, @@ -259,6 +313,40 @@ def get_cluster_info(self): """ return [s.get_cluster_info() for s in self.sessions] + def get_good_units_info(self): + """ + Get some basic high-level information for each good unit. This is mostly just the + information seen in the table in phy, plus their region. + """ + info = {} + + for m, mouse in enumerate(self.mouse_ids): + info[mouse] = {} + mouse_sessions = [] + for session in self.sessions: + if mouse in session.name: + mouse_sessions.append(session) + + for i, session in enumerate(mouse_sessions): + info[mouse][i] = session.get_good_units_info() + + long_df = pd.concat( + info[mouse], + axis=0, + names=["session", "unit_idx"], + ) + info[mouse] = long_df + + mouse_info = info + + info_pooled = pd.concat( + info, axis=0, copy=False, + keys=info.keys(), + names=["mouse", "session", "unit_idx"], + ) + + return mouse_info, info_pooled + def get_spike_widths(self, units=None): """ Get the widths of spikes for units matching the specified criteria. @@ -291,7 +379,10 @@ def get_spike_waveforms(self, units=None): waveforms[i] = session.get_spike_waveforms(units=units[i]) else: waveforms[i] = session.get_spike_waveforms() + assert 0 + df.add_prefix(i) + #TODO: get concat waveforms for each mouse df = pd.concat( waveforms.values(), axis=1, copy=False, keys=waveforms.keys(), @@ -299,6 +390,83 @@ def get_spike_waveforms(self, units=None): ) return df + def get_waveform_metrics(self, units=None): + """ + Get waveform metrics of mean waveform for units matching the specified + criteria; separated by mouse. + """ + waveform_metrics = {} + + for m, mouse in enumerate(self.mouse_ids): + mouse_sessions = [] + waveform_metrics[mouse] = {} + for session in self.sessions: + if mouse in session.name: + mouse_sessions.append(session) + + for i, session in enumerate(mouse_sessions): + if units: + if units[i]: + waveform_metrics[mouse][i] = session.get_waveform_metrics(units=units[i]) + else: + waveform_metrics[mouse][i] = session.get_waveform_metrics() + + long_df = pd.concat( + waveform_metrics[mouse], + axis=0, + names=["session", "unit_idx"], + ) + # drop nan rows + long_df.dropna(inplace=True) + waveform_metrics[mouse] = long_df + + mouse_waveform_metrics = waveform_metrics + + waveform_metrics_pooled = pd.concat( + waveform_metrics, axis=0, copy=False, + keys=waveform_metrics.keys(), + names=["mouse", "session", "unit_idx"], + ) + + return mouse_waveform_metrics, waveform_metrics_pooled + + def get_spike_times(self, units): + """ + Get spike times of each units, separated by mouse. + """ + spike_times = {} + + for m, mouse in enumerate(self.mouse_ids): + mouse_sessions = [] + spike_times[mouse] = {} + for session in self.sessions: + if mouse in session.name: + mouse_sessions.append(session) + + for i, session in enumerate(mouse_sessions): + spike_times[mouse][i] = session._get_spike_times()[units[i]] + #spike_times[mouse][i] = spike_times[mouse][i].add_prefix(f'{i}_') + + df = pd.concat( + spike_times[mouse], + axis=1, + names=["session", "unit"], + ) + spike_times[mouse] = df + + mouse_spike_times = spike_times + + """ + spike_times_pooled = pd.concat( + mouse_spike_times, axis=1, copy=False, + keys=mouse_spike_times.keys(), + names=["mouse", "session", "unit"], + ) + """ + + return mouse_spike_times + + def get_aligned_spike_rate_CI(self, *args, units=None, **kwargs): """ Get the confidence intervals of the mean firing rates within a window aligned to @@ -325,3 +493,138 @@ def get_session_by_name(self, name: str): if session.name == name: return session raise PixelsError + + + def get_positional_data(self, *args, units=None, **kwargs): + """ + Get positional firing rate for aligned vr trials. + Check behaviours.base.Behaviour.get_positional_data for usage + information. + """ + trials = {} + for i, session in enumerate(self.sessions): + name = session.name + result = None + if units: + if units[name]: + result = session.get_positional_data( + *args, + units=units[name], + **kwargs, + ) + else: + result = session.get_positional_data(*args, **kwargs) + if result is not None: + trials[name] = result + + level_names = ["session", "stream", "start", "unit", "trial"] + pos_fr = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="pos_fr", + level_names=level_names, + ) + pos_fc = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="pos_fc", + level_names=level_names, + ) + occupancies = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="occupancy", + level_names=["session", "stream", "trial"], + ) + df = { + "pos_fr": pos_fr, + "pos_fc": pos_fc, + "occupancy": occupancies, + } + + return df + + + def get_binned_trials(self, *args, units=None, **kwargs): + """ + Get binned firing rate and spike count for aligned vr trials. + Check behaviours.base.Behaviour.get_binned_trials for usage information. + """ + # TODO jun 21 2025: + # can we combine this func with get_positional_data since they are + # basically the same, we just need to add `use_binned` bool in the arg + session_names = [session.name for session in self.sessions] + trials = {} + if not units is None: + for name in units.keys(): + session = self.sessions[session_names.index(name)] + trials[name] = session.get_binned_trials( + *args, + units=units[name], + **kwargs, + ) + else: + for i, session in enumerate(self.sessions): + name = session.name + result = session.get_binned_trials(*args, **kwargs) + if not result is None: + trials[name] = result + + level_names = ["session", "stream", "start", "unit", "trial"] + bin_fr = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="pos_fr", + level_names=level_names, + ) + bin_fc = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="pos_fc", + level_names=level_names, + ) + bin_occupancies = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="occupancy", + level_names=["session", "stream", "trial"], + ) + df = { + "bin_fr": bin_fr, + "bin_fc": bin_fc, + "bin_occupancy": bin_occupancies, + } + + return df + + + def sync_vr(self, vr): + """ + Synchronise virtual reality data of a mouse (or mice) with pixels + streams. + """ + trials = {} + for i, session in enumerate(self.sessions): + # vr is a vision-in-darkness mouse object + vr_session = vr.sessions[i] + assert session.name.split("_")[0] in vr_session.name + + session.sync_vr(vr_session) + + return None + + + def get_spike_chance(self, *args, **kwargs): + chance = {} + for i, session in enumerate(self.sessions): + name = session.name + chance[name] = session.get_spike_chance(*args, **kwargs) + + return chance + + + def get_binned_chance(self, *args, **kwargs): + """ + Get binned chance firing rate and spike count for aligned vr trials. + Check behaviours.base.Behaviour.get_binned_chance for usage information. + """ + binned = {} + for i, session in enumerate(self.sessions): + name = session.name + binned[name] = session.get_binned_chance(*args, **kwargs) + + return binned diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 73b64e1..03f153b 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -17,6 +17,7 @@ from nptdms import TdmsFile from pixels.error import PixelsError +from pixels.configs import * def get_data_files(data_dir, session_name): @@ -33,91 +34,199 @@ def get_data_files(data_dir, session_name): Returns ------- - A list of dicts, where each dict corresponds to one recording. The dict will contain - these keys to identify data files: - - - spike_data - - spike_meta - - lfp_data - - lfp_meta - - behaviour - - camera_data - - camera_meta - + A nested dicts, where each dict corresponds to one session. Data is + separated to two main categories: pixels and behaviour. + + In `pixels`, data is separated by their stream id, this is to allow: + - easy concatenation of recordings files from the same probe, i.e., + stream id, + - different numbers of pixels recordings and behaviour recordings, + + {session_name:{ + "pixels":{ + "imec0":{ + "ap_raw": [PosixPath("name.bin")], + "ap_meta": [PosixPath("name.meta")], + "preprocessed": spikeinterface recording obj, + "ap_extracted": spikeinterface recording obj, + "ap_whitened": spikeinterface recording obj, + "lfp_extracted": spikeinterface recording obj, + "surface_depth": PosixPath("name.yaml"), + "sorting_analyser": PosixPath("name.zarr"), + }, + "imecN":{ + }, + }, + "behaviour":{ + "vr": PosixPath("name.h5"), + "action_labels": PosixPath("name.npz"), + }, + } """ if session_name != data_dir.stem: - data_dir = list(data_dir.glob(f'{session_name}*'))[0] - files = [] + data_dir = list(data_dir.glob(f"{session_name}*"))[0] + + files = {} - spike_data = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*')) - spike_meta = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*')) - lfp_data = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].lf.bin*')) - lfp_meta = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].lf.meta*')) - behaviour = sorted(glob.glob(f'{data_dir}/[0-9a-zA-Z_-]*([0-9]).tdms*')) + ap_raw = sorted(glob.glob(f"{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*")) + ap_meta = sorted(glob.glob(f"{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*")) - if not spike_data: + if not ap_raw: raise PixelsError(f"{session_name}: could not find raw AP data file.") - if not spike_meta: + if not ap_meta: raise PixelsError(f"{session_name}: could not find raw AP metadata file.") - if not lfp_data: - raise PixelsError(f"{session_name}: could not find raw LFP data file.") - if not lfp_meta: - raise PixelsError(f"{session_name}: could not find raw LFP metadata file.") - - camera_data = [] - camera_meta = [] - for rec in behaviour: - name = Path(rec).stem - rec_vids = sorted(glob.glob(f'{data_dir}/*{name}-*.tdms*')) - vids = [v for v in rec_vids if 'meta' not in v] - camera_data.append(vids) - meta = [v for v in rec_vids if 'meta' in v] - camera_meta.append(meta) - - for num, spike_recording in enumerate(spike_data): - recording = {} - recording['spike_data'] = original_name(spike_recording) - recording['spike_meta'] = original_name(spike_meta[num]) - recording['lfp_data'] = original_name(lfp_data[num]) - recording['lfp_meta'] = original_name(lfp_meta[num]) - - if behaviour: - if len(behaviour) == len(spike_data): - recording['behaviour'] = original_name(behaviour[num]) - else: - recording['behaviour'] = original_name(behaviour[0]) - recording['behaviour_processed'] = recording['behaviour'].with_name( - recording['behaviour'].stem + '_processed.h5' - ) - # We only have videos if we also have behavioural TDMS data - if len(camera_data) > num: - recording['camera_data'] = [original_name(d) for d in camera_data[num]] - recording['camera_meta'] = [original_name(d) for d in camera_meta[num]] - recording['motion_index'] = Path(f'motion_index_{num}.npy') - recording['motion_tracking'] = Path(f'motion_tracking_{num}.h5') - else: - recording['behaviour'] = None - recording['behaviour_processed'] = None + pupil_raw = sorted(glob.glob(f"{data_dir}/behaviour/pupil_cam/*.avi*")) + + behaviour = { + "vr_synched": [], + "action_labels": [], + "pupil_raw": pupil_raw, + } + + pixels = {} + for r, rec in enumerate(ap_raw): + stream_id = rec[-12:-4] + probe_id = stream_id[:-3] + # separate recordings by their stream ids + if stream_id not in pixels: + pixels[stream_id] = { + "ap_raw": [], # there could be mutliple, thus list + "ap_meta": [], + "si_rec": None, # there could be only one, thus None + "preprocessed": None, + "ap_extracted": None, + "ap_whitened": None, + "lfp_extracted": None, + "CatGT_ap_data": [], + "CatGT_ap_meta": [], + } + + base_name = original_name(rec) + pixels[stream_id]["ap_raw"].append(base_name) + pixels[stream_id]["ap_meta"].append(original_name(ap_meta[r])) + + behaviour["vr_synched"].append(base_name.with_name( + f"{session_name}_{probe_id}_vr_synched.h5" + )) + behaviour["action_labels"].append(base_name.with_name( + f"action_labels_{probe_id}.npz" + )) + + # >>> spikeinterface cache >>> + # extracted & motion corrected ap stream, 300Hz+ + pixels[stream_id]["ap_motion_corrected"] = base_name.with_name( + f"{base_name.stem}.mcd.zarr" + ) + # extracted & motion corrected lfp stream, 500Hz- + pixels[stream_id]["lfp_motion_corrected"] = base_name.with_name( + f"{base_name.stem[:-3]}.lf.mcd.zarr" + ) + pixels[stream_id]["detected_peaks"] = base_name.with_name( + f"{base_name.stem}_detected_peaks.h5" + ) + sorted_stream_dir = base_name.parent / f"sorted_stream_{probe_id[-1]}" + pixels[stream_id]["sorting_analyser"] = sorted_stream_dir /\ + "curated_sa.zarr" + # if performed units merge, we have an updated sorting analyser + pixels[stream_id]["merged_sorting_analyser"] = sorted_stream_dir /\ + "merged_sa.zarr" + # <<< spikeinterface cache <<< + + # depth info of probe + pixels[stream_id]["surface_depth"] = base_name.with_name( + f"{session_name}_{probe_id}_surface_depth.yaml" + ) + pixels[stream_id]["clustered_channels"] = base_name.with_name( + f"{session_name}_{stream_id}_channel_clustering_results.h5" + ) + + # psd of theta and ripple to identify CA1 pyramidal layer + pixels[stream_id]["bandwise_psd"] = base_name.with_name( + f"{session_name}_{probe_id}_bandwise_psd.h5" + ) - recording['action_labels'] = Path(f'action_labels_{num}.npy') - recording['spike_processed'] = recording['spike_data'].with_name( - recording['spike_data'].stem + '_processed.h5' + # TODO mar 5 2025: + # maybe do NOT put shuffled data in here, cuz there will be different + # trial conditions, better to cache them??? + + # shuffled response for each unit, in light & dark conditions, to get + # the chance + # memmaps for temporary storage + pixels[stream_id]["spiked_shuffled_memmap"] = base_name.with_name( + f"{session_name}_{probe_id}_spiked_shuffled.bin" + ) + pixels[stream_id]["fr_shuffled_memmap"] = base_name.with_name( + f"{session_name}_{probe_id}_fr_shuffled.bin" ) - recording['spike_rate_processed'] = Path(f'spike_rate_{num}.h5') - recording['lfp_processed'] = recording['lfp_data'].with_name( - recording['lfp_data'].stem + '_processed.npy' + pixels[stream_id]["shuffled_shape"] = base_name.with_name( + f"{session_name}_{probe_id}_shuffled_shape.json" ) - recording['lfp_sd'] = recording['lfp_data'].with_name( - recording['lfp_data'].stem + '_sd.json' + pixels[stream_id]["shuffled_index"] = base_name.with_name( + f"{session_name}_{probe_id}_shuffled_index.h5" ) - recording['clustered_channels'] = recording['lfp_data'].with_name( - f'channel_clustering_results_{num}.h5' + pixels[stream_id]["shuffled_columns"] = base_name.with_name( + f"{session_name}_{probe_id}_shuffled_columns.h5" + ) + # .h5 files + pixels[stream_id]["spiked_shuffled"] = base_name.with_name( + f"{session_name}_{probe_id}_spiked_shuffled.h5" + ) + pixels[stream_id]["fr_shuffled"] = base_name.with_name( + f"{session_name}_{probe_id}_fr_shuffled.h5" + ) + + # noise in curated units + pixels[stream_id]["noisy_units"] = base_name.with_name( + f"{session_name}_{probe_id}_noisy_units.yaml" ) - recording['depth_info'] = recording['lfp_data'].with_name( - f'depth_info_{num}.json' + # mergeable units in curated units + pixels[stream_id]["mergeable_units"] = base_name.with_name( + f"{session_name}_{probe_id}_mergeable_units.yaml" ) - files.append(recording) + + # old catgt data + pixels[stream_id]["CatGT_ap_data"].append( + str(base_name).replace("t0", "tcat") + ) + pixels[stream_id]["CatGT_ap_meta"].append( + str(base_name).replace("t0", "tcat") + ) + + # histology + mouse_id = session_name.split("_")[-1] + pixels[stream_id]["depth_info"] = base_name.with_name( + f"{mouse_id}_depth_info.yaml" + ) + + # identified faulty channels + pixels[stream_id]["faulty_channels"] = base_name.with_name( + f"{session_name}_{probe_id}_faulty_channels.yaml" + ) + + #pixels[stream_id]["spike_rate_processed"] = base_name.with_name( + # f"spike_rate_{stream_id}.h5" + #) + + if pupil_raw: + behaviour["pupil_processed"] = [] + behaviour["motion_index"] = [] + behaviour["motion_tracking"] = [] + for r, rec in enumerate(pupil_raw): + behaviour["pupil_processed"].append(base_name.with_name( + session_name + "_pupil_processed.h5" + )) + behaviour["motion_index"] = base_name.with_name( + session_name + "_motion_index.npz" + ) + behaviour["motion_tracking"] = base_name.with_name( + session_name + "_motion_tracking.h5" + ) + + files = { + "pixels": pixels, + "behaviour": behaviour, + } return files @@ -127,7 +236,7 @@ def original_name(path): Get the original name of a file, uncompressed, as a pathlib.Path. """ name = os.path.basename(path) - if name.endswith('.tar.gz'): + if name.endswith(".tar.gz"): name = name[:-7] return Path(name) @@ -172,13 +281,13 @@ def read_bin(path, num_chans, channel=None): Returns ------- numpy.memmap array : A 2D memory-mapped array containing containing the binary - file's data. + file"s data. """ if not isinstance(num_chans, int): num_chans = int(num_chans) - mapping = np.memmap(path, np.int16, mode='r').reshape((-1, num_chans)) + mapping = np.memmap(path, np.int16, mode="r").reshape((-1, num_chans)) if channel is not None: mapping = mapping[:, channel] @@ -227,10 +336,10 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): Parameters ---------- video : numpy.ndarray, or generator - Video data to save to file. It's dimensions should be (duration, height, width) + Video data to save to file. It"s dimensions should be (duration, height, width) and data should be of uint8 type. The file extension determines the resultant file type. Alternatively, this can be a generator that yields frames of this - description, in which case 'dims' must also be passed. + description, in which case "dims" must also be passed. path : string / pathlib.Path object File to which the video will be saved. @@ -239,7 +348,7 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): The frame rate of the output video. dims : (int, int) - (height, width) of video. This is only needed if 'video' is a generator that + (height, width) of video. This is only needed if "video" is a generator that yields frames, as then the shape cannot be taken from it directly. """ @@ -252,8 +361,8 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): process = ( ffmpeg - .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}', r=frame_rate) - .output(path.as_posix(), pix_fmt='yuv420p', r=frame_rate, crf=0, vcodec='libx264') + .input("pipe:", format="rawvideo", pix_fmt="rgb24", s=f"{width}x{height}", r=frame_rate) + .output(path.as_posix(), pix_fmt="yuv420p", r=frame_rate, crf=0, vcodec="libx264") .overwrite_output() .run_async(pipe_stdin=True) ) @@ -274,7 +383,7 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): raise PixelsError(f"Video creation failed: {path}") -def read_hdf5(path): +def read_hdf5(path, key="df"): """ Read a dataframe from a h5 file. @@ -285,14 +394,17 @@ def read_hdf5(path): Returns ------- - pandas.DataFrame : The dataframe stored within the hdf5 file under the name 'df'. + pandas.DataFrame : The dataframe stored within the hdf5 file under the name "df". """ - df = pd.read_hdf(path, 'df') + df = pd.read_hdf( + path_or_buf=path, + key=key, + ) return df -def write_hdf5(path, df): +def write_hdf5(path, df, key="df", mode="w", format="fixed"): """ Write a dataframe to an h5 file. @@ -304,18 +416,36 @@ def write_hdf5(path, df): df : pd.DataFrame Dataframe to save to h5. - """ - df.to_hdf(path, 'df', mode='w') + key : str + identifier for the group in the store. + Default: "df". + + mode : str + mode to open file. + Default: "w" write. + Options: + "a": append, if file does not exists it is created. + "r+": similar to "a" but file must exists. + """ + df.to_hdf( + path_or_buf=path, + key=key, + mode=mode, + format=format, + complevel=9, + #complib="bzip2", # slower but higher compression ratio + complib="blosc:lz4hc", + ) - print('HDF5 saved to ', path) + print("HDF5 saved to", path) return -def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): +def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt, of_date=None): """ Get a list of recording sessions for the specified mice, excluding those whose - metadata contain '"exclude" = True'. + metadata contain "'exclude' = True". Parameters ---------- @@ -342,15 +472,39 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): if not isinstance(mouse_ids, (list, tuple, set)): mouse_ids = [mouse_ids] sessions = {} - raw_dir = data_dir / 'raw' + raw_dir = data_dir / "raw" - for mouse in mouse_ids: - mouse_sessions = list(raw_dir.glob(f'*{mouse}*')) + for m, mouse in enumerate(mouse_ids): + mouse_sessions = sorted(list(raw_dir.glob(f"*{mouse}"))) if not mouse_sessions: - print(f'Found no sessions for: {mouse}') + print(f"Found no sessions for: {mouse}") continue + # allows different session date formats + session_dates = sorted([ + datetime.datetime.strptime(s.stem.split("_")[0], session_date_fmt) + for s in mouse_sessions + ]) + + if of_date is not None: + if isinstance(of_date, str): + date_list = [of_date] + elif is_nested_list(of_date): + date_list = of_date[m] + else: + date_list = of_date + + date_sessions = [] + for date in date_list: + date_struct = datetime.datetime.strptime(date, session_date_fmt) + date_sessions.append(mouse_sessions[session_dates.index(date_struct)]) + logging.info( + f"\n> Getting one session from {mouse} on " + f"{datetime.datetime.strftime(date_struct, '%Y %B %d')}." + ) + mouse_sessions = date_sessions + if not meta_dir: # Do not collect metadata for session in mouse_sessions: @@ -362,43 +516,36 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): data_dir=data_dir, )) continue - - meta_file = meta_dir / (mouse + '.json') - with meta_file.open() as fd: - mouse_meta = json.load(fd) - # az: change date format into yyyymmdd - session_dates = [ - datetime.datetime.strptime(s.stem.split("_")[0], session_date_fmt) for s in mouse_sessions - ] - - if len(session_dates) != len(set(session_dates)): - raise PixelsError(f"{mouse}: Data folder dates must be unique.") - - included_sessions = set() - for i, session in enumerate(mouse_meta): - try: - meta_date = datetime.datetime.strptime(session['date'], '%Y-%m-%d') - except ValueError: - # also allow this format - meta_date = datetime.datetime.strptime(session['date'], '%Y%m%d') - except TypeError: - raise PixelsError(f"{mouse} session #{i}: 'date' not found in JSON.") - - for index, ses_date in enumerate(session_dates): - if ses_date == meta_date and not session.get('exclude', False): - name = mouse_sessions[index].stem - if name not in sessions: - sessions[name] = [] - sessions[name].append(dict( - metadata=session, - data_dir=data_dir, - )) - included_sessions.add(name) - - if included_sessions: - print(f'{mouse} has {len(included_sessions)} sessions:', ", ".join(included_sessions)) else: - print(f'No session dates match between folders and metadata for: {mouse}') + meta_file = meta_dir / (mouse + ".json") + with meta_file.open() as fd: + mouse_meta = json.load(fd) + + if len(session_dates) != len(set(session_dates)): + raise PixelsError(f"{mouse}: Data folder dates must be unique.") + + included_sessions = set() + for i, session in enumerate(mouse_meta): + try: + meta_date = datetime.datetime.strptime(session["date"], session_date_fmt) + except TypeError: + raise PixelsError(f"{mouse} session #{i}: 'date' not found in JSON.") + + for index, ses_date in enumerate(session_dates): + if ses_date == meta_date and not session.get("exclude", False): + name = mouse_sessions[index].stem + if name not in sessions: + sessions[name] = [] + sessions[name].append(dict( + metadata=session, + data_dir=data_dir, + )) + included_sessions.add(name) + + if included_sessions: + print(f"{mouse} has {len(included_sessions)} sessions:", ", ".join(included_sessions)) + else: + print(f"No session dates match between folders and metadata for: {mouse}") return sessions @@ -508,8 +655,8 @@ def tdms_to_video(tdms_path, meta_path, output_path): process = ( ffmpeg - .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}', r=fps) - .output(output_path.as_posix(), pix_fmt='yuv420p', r=fps, crf=20, vcodec='libx264') + .input("pipe:", format="rawvideo", pix_fmt="rgb24", s=f"{width}x{height}", r=fps) + .output(output_path.as_posix(), pix_fmt="yuv420p", r=fps, crf=20, vcodec="libx264") .overwrite_output() .run_async(pipe_stdin=True) ) @@ -643,3 +790,133 @@ def stream_video(video, length=None): length -= 1 if length == 0: break + +def reindex_by_longest(dfs, idx_names=None, col_names=None, level=0, sort=True, + return_format="array"): + """ + params + === + dfs: dict, dictionary with pandas dataframe as values. + + return_format: str, format of return value. + "array": stacked np array. + "dataframe": concatenated pandas dataframe. + + names: str or list of str, names for levels of concatenated dataframe. + + return + === + np.array or pd.DataFrame. + """ + if return_format == "array": + # align all trials by index + indices = list(set().union( + *[df.index for df in dfs.values()]) + ) + # reindex by the longest + reidx_dfs = {key: df.reindex(index=indices) + for key, df in dfs.items()} + # stack df values into np array + # NOTE: this create multidimensional data, different from if return + # format is df! + output = np.stack( + [df.values for df in reidx_dfs.values()], + axis=-1, + ) + + elif return_format == "dataframe": + if isinstance(dfs, dict): + # stack dfs vertically + stacked_df = pd.concat(dfs, axis=0) + # set index name + if idx_names: + stacked_df.index.names = idx_names + if col_names: + stacked_df.columns.names = col_names + elif isinstance(dfs, pd.DataFrame): + stacked_df = dfs + + # unstack df at level + output = stacked_df.unstack(level=level, sort=sort) + del stacked_df + + return output + +def is_nested_dict(d): + """ + Returns True if at least one value in dictionary d is a dict. + """ + return any(isinstance(v, dict) for v in d.values()) + + +def is_nested_list(ls): + """ + Returns True if at least one value in list d is a list. + """ + return any(isinstance(item, list) for item in ls) + + +def save_index_to_frame(df, path): + idx_df = df.index.to_frame(index=False) + write_hdf5( + path=path, + df=idx_df, + key="multiindex", + ) + # NOTE: to reconstruct: + # recs_idx = pd.MultiIndex.from_frame(idx_df) + + +def save_cols_to_frame(df, path): + col_df = df.columns.to_frame(index=False) + write_hdf5( + path=path, + df=col_df, + key="cols", + ) + # NOTE: to reconstruct: + # df.columns = col_df.values + + +def get_aligned_data_across_sessions(trials, key, level_names): + """ + Get aligned trials across sessions. + + params + === + trials: nested dict, aligned trials from multiple sessions. + keys: session_name -> stream_id -> "fr", "positions", "spiked" + + key: str, type of data to get. + "fr": firing rate, longest trial time x (unit x trial) + "spiked": spiked boolean, longest trial time x (unit x trial) + "positions": trial positions, longest trial time x trial + + return + === + df: pandas df, concatenated key data across sessions. + """ + per_session = {} + for s_name, s_data in trials.items(): + key_data = {} + for stream_id, stream_data in s_data.items(): + key_data[stream_id] = stream_data[key] + + # concat at stream level + per_session[s_name] = pd.concat( + key_data, + axis=1, + names=level_names[1:], + ) + + # concat at session level + df = pd.concat( + per_session, + axis=1, + names=level_names, + ) + + # swap stream and session so that stream is the most outer level + output = df.swaplevel("session", "stream", axis=1) + + return output diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py new file mode 100644 index 0000000..06da97b --- /dev/null +++ b/pixels/pixels_utils.py @@ -0,0 +1,1973 @@ +""" +This module provides utilities for pixels data. +""" +# annotations not evaluated at runtime +from __future__ import annotations + +import multiprocessing as mp +from multiprocessing import shared_memory +from concurrent.futures import ProcessPoolExecutor, as_completed +import json +from pathlib import Path +import zarr +import gc + +import xarray as xr +from numcodecs import Blosc, VLenUTF8 + +import numpy as np +import pandas as pd + +from scipy import stats +import statsmodels.formula.api as smf +from statsmodels.stats.multitest import multipletests +from patsy import build_design_matrices + +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.sorters as ss +import spikeinterface.curation as sc +import spikeinterface.exporters as sexp +import spikeinterface.preprocessing as spre +import spikeinterface.postprocessing as spost +import spikeinterface.qualitymetrics as sqm + +import pixels.signal_utils as signal +from pixels.ioutils import write_hdf5, reindex_by_longest +from pixels.error import PixelsError +from pixels.configs import * +from pixels.constants import * +from pixels.decorators import _df_to_zarr_via_xarray + +from common_utils import math_utils + +def load_raw(paths, stream_id): + """ + Load raw recording file from spikeglx. + """ + recs = [] + for p, path in enumerate(paths): + # NOTE: if it is catgt data, pass directly `catgt_ap_data` + logging.info(f"\n> Getting the orignial recording...") + # load # recording # file + rec = se.read_spikeglx( + folder_path=path.parent, + stream_id=stream_id, + stream_name=path.stem, + all_annotations=True, # include # all # annotations + ) + recs.append(rec) + if len(recs) > 1: + # concatenate # runs # for # each # probe + concat_recs = si.concatenate_recordings(recs) + else: + concat_recs = recs[0] + + return rec + + +def preprocess_raw(rec, surface_depths, faulty_channels): + group_ids = rec.get_channel_groups() + + if np.unique(group_ids).size < 4: + # correct group id if not all shanks used + group_ids = correct_group_id(rec) + # change the group id + rec.set_channel_groups(group_ids) + + if not np.all(group_ids == group_ids[0]): + # if more than one shank used + preprocessed = [] + # split by groups + groups = rec.split_by("group") + for g, group in groups.items(): + logging.info(f"\n> Preprocessing shank {g}") + # get brain surface depth of shank + surface_depth = surface_depths[g] + cleaned = _preprocess_raw(group, surface_depth, faulty_channels[g]) + preprocessed.append(cleaned) + # aggregate groups together + preprocessed = si.aggregate_channels(preprocessed) + else: + # if only one shank used, check which shank + unique_id = np.unique(group_ids)[0] + # get brain surface depth of shank + surface_depth = surface_depths[unique_id] + # preprocess + preprocessed = _preprocess_raw( + rec, + surface_depth, + faulty_channels[unique_id], + ) + + return preprocessed + + +def _preprocess_raw(rec, surface_depth, faulty_channels): + """ + Implementation of preprocessing on raw pixels data. + """ + # correct phase shift + print("\t> step 1: do phase shift correction.") + rec_ps = spre.phase_shift(rec) + + # remove bad channels from sorting + print("\t> step 2: remove bad channels.") + # remove pre-identified bad channels + chan_names = rec_ps.get_property("channel_name") + faulty_ids = rec_ps.channel_ids[np.isin(chan_names, faulty_channels)] + rec_removed = rec_ps.remove_channels(faulty_ids) + + # detect bad channels + bad_chan_ids, chan_labels = spre.detect_bad_channels( + rec_removed, + outside_channels_location="top", + ) + labels, counts = np.unique(chan_labels, return_counts=True) + for label, count in zip(labels, counts): + print(f"\t\t> Found {count} channels labelled as {label}.") + rec_removed = rec_removed.remove_channels(bad_chan_ids) + + # get channel group id and use it to index into brain surface channel depth + shank_id = np.unique(rec_removed.get_channel_groups())[0] + # get channel depths + chan_depths = rec_removed.get_channel_locations()[:, 1] + # get channel ids + chan_ids = rec_removed.channel_ids + # remove channels outside by using identified brain surface depths + outside_chan_ids = chan_ids[chan_depths > surface_depth] + rec_clean = rec_removed.remove_channels(outside_chan_ids) + print( + f"\t\t> Removed {outside_chan_ids.size} outside channels " + f"above {surface_depth}um." + ) + + return rec_clean + + +def CMR(rec, dtype=np.int16): + cmr = spre.common_reference( + rec, + operator="median", + dtype=dtype, + ) + return cmr + + +def CAR(rec, dtype=np.int16): + car = spre.common_reference( + rec, + operator="average", + dtype=np.int16, + ) + return car + + +def correct_lfp_motion(rec, mc_method="dredge"): + if mc_method == "dredge": + em_method = mc_method+"_lfp" + else: + em_method = spre.motion.motion_options_preset[mc_method][ + "estimate_motion_kwargs" + ]["method"] + raise NotImplementedError("> Not implemented.") + + +def correct_ap_motion(rec, mc_method="dredge"): + """ + Correct motion of recording. + + params + === + mc_method: str, motion correction method. + Default: "dredge". + (as of jan 2025, dredge performs better than ks motion correction.) + "ks": let kilosort do motion correction. + + return + === + None + """ + logging.info(f"\n> Correcting motion with {mc_method}.") + + if mc_method == "dredge": + em_method = mc_method+"_ap" + else: + em_method = spre.motion.motion_options_preset[mc_method][ + "estimate_motion_kwargs" + ]["method"] + + # reduce spatial window size for four-shank + estimate_motion_kwargs = { + "method": f"{em_method}", + "win_step_um": 100, + "win_margin_um": -150, + "verbose": True, + } + + # make sure recording dtype is float for interpolation + interpolate_motion_kwargs = { + "dtype": np.float32, + } + + mcd = spre.correct_motion( + rec, + preset=mc_method, + estimate_motion_kwargs=estimate_motion_kwargs, + interpolate_motion_kwargs=interpolate_motion_kwargs, + ) + + # convert to int16 to save space + if not mcd.dtype == np.dtype("int16"): + mcd = spre.astype(mcd, dtype=np.int16) + + return mcd + + +def detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): + """ + Get a sense of possible drifts in the recordings by looking at a + "positional raster plot", i.e. the depth of the spike as function of + time. To do so, we need to detect the peaks, and then to localize them + in space. + + params + === + rec: spikeinterface recording extractor. + + loc_method: str, peak location method. + Default: "monopolar_triangulation" + list of methods: + "center_of_mass", "monopolar_triangulation", "grid_convolution" + to learn more, check: + https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html + """ + shank_groups = rec.get_channel_groups() + level_names = ["shank", "spike_properties"] + + if not np.all(shank_groups == shank_groups[0]): + # split by groups + groups = rec.split_by("group") + dfs = [] + for g, group in groups.items(): + logging.info(f"\n> Estimate drift of shank {g}") + dfs.append(_detect_n_localise_peaks(group, loc_method)) + # concat shanks + df = pd.concat( + dfs, + axis=1, + keys=groups.keys(), + names=level_names, + ) + else: + df = _detect_n_localise_peaks(rec, loc_method) + # add shank level on top + shank_id = shank_groups[0] + df.columns = pd.MultiIndex.from_tuples( + [(shank_id, col) for col in df.columns], + names=level_names, + ) + + return df + + +def _detect_n_localise_peaks(rec, loc_method): + """ + implementation of drift estimation. + """ + from spikeinterface.sortingcomponents.peak_detection\ + import detect_peaks + from spikeinterface.sortingcomponents.peak_localization\ + import localize_peaks + + logging.info("\n> step 1: detect peaks") + peaks = detect_peaks( + recording=rec, + method="by_channel", + detect_threshold=5, + exclude_sweep_ms=0.2, + ) + + logging.info( + "\n> step 2: localize the peaks to get a sense of their putative " + "depths" + ) + peak_locations = localize_peaks( + recording=rec, + peaks=peaks, + method=loc_method, + ) + + # get sampling frequency + fs = rec.sampling_frequency + + # save it as df + df_peaks = pd.DataFrame(peaks) + df_peak_locs = pd.DataFrame(peak_locations) + df = pd.concat([df_peaks, df_peak_locs], axis=1) + # add timestamps and channel ids + df["timestamp"] = df.sample_index / fs + df["channel_id"] = rec.get_channel_ids()[df.channel_index.values] + + return df + + +def extract_band(rec, freq_min, freq_max, ftype="butter"): + """ + Band pass filter recording. + + params + === + freq_min: float, high-pass cutoff corner frequency. + + freq_max: float, low-pass cutoff corner frequency. + + ftype: str, filter type. + since its posthoc, we use 5th order acausal filter, and takes + second-order sections (SOS) representation of the filter, + forward-backward. but more filters to choose from, e.g., bessel with + filter_order=2, presumably preserves waveform better? see lussac. + + return + === + band: spikeinterface recording object. + """ + band = spre.bandpass_filter( + rec, + freq_min=freq_min, + freq_max=freq_max, + margin_ms=5.0, + filter_order=5, + ftype=ftype, + direction="forward-backward", + ) + + return band + + +def whiten(rec): + whitened = spre.whiten( + recording=rec, + dtype=np.float32, + #dtype=np.int16, + #int_scale=200, # scale traces value to sd of 200, in line with ks4 + mode="local", + radius_um=240.0, # 16 nearby chans in line with ks4 + ) + + return whitened + + +def sort_spikes(rec, sa_rec, output, curated_sa_dir, ks_image_path, ks4_params, + per_shank=False): + """ + Sort spikes with kilosort 4, curate sorting, save sorting analyser to disk, + and export results to disk. + + params + === + rec: spikeinterface recording object. + + sa_rec: spikeinterface recording object for creating sorting analyser. + + output: path object, directory of output. + + curated_sa_dir: path object, directory to save curated sorting analyser. + + ks_image_path: path object, directory of local kilosort 4 singularity image. + + ks4_params: dict, parameters for kilosort 4. + + per_shank: bool, whether to sort recording per shank. + Default: False (as of may 2025, sort shanks separately by ks4 gives less + units) + + return + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. + """ + + # sort spikes + if np.unique(rec.get_channel_groups()).size > 1 and per_shank: + # per shank + sorting, recording = _sort_spikes_by_group( + rec, + sa_rec, + output, + ks_image_path, + ks4_params, + ) + else: + # all together + sorting, recording = _sort_spikes( + rec, + sa_rec, + output, + ks_image_path, + ks4_params, + ) + + # curate sorting + sa, curated_sa = _curate_sorting( + sorting, + recording, + output, + ) + + # export sorting analyser + _export_sorting_analyser( + sa, + curated_sa, + output, + curated_sa_dir, + ) + + return None + + +def _sort_spikes_by_group(rec, sa_rec, output, ks_image_path, ks4_params): + """ + Sort spikes with kilosort 4 by group/shank. + + params + === + rec: spikeinterface recording object. + + sa_rec: spikeinterface recording object for creating sorting analyser. + if None, then use the temp_wh.dat from ks output. + + output: path object, directory of output. + + ks_image_path: path object, directory of local kilosort 4 singularity image. + + ks4_params: dict, parameters for kilosort 4. + + return + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. + """ + logging.info("\n> Sorting spikes per shank.") + + # run sorter per shank + sorting = ss.run_sorter_by_property( + sorter_name="kilosort4", + recording=rec, + grouping_property="group", + folder=output, + singularity_image=ks_image_path, + remove_existing_folder=True, + verbose=True, + **ks4_params, + ) + + if not sa_rec: + recs = [] + groups = rec.split_by("group") + for g, group in groups.items(): + ks_preprocessed = se.read_binary( + file_paths=output/f"{g}/sorter_output/temp_wh.dat", + sampling_frequency=group.sampling_frequency, + dtype=np.int16, + num_channels=group.get_num_channels(), + is_filtered=True, + ) + + # attach probe # to ks4 preprocessed recording, from the raw + with_probe = ks_preprocessed.set_probe(group.get_probe()) + # set properties to make sure sorting & sorting sa have all + # probe # properties to form correct rec_attributes, esp + with_probe._properties = group._properties + + # >>> annotations >>> + annotations = group.get_annotation_keys() + annotations.remove("is_filtered") + for ann in annotations: + with_probe.set_annotation( + annotation_key=ann, + value=group.get_annotation(ann), + overwrite=True, + ) + # <<< annotations <<< + recs.append(with_probe) + + recording = si.aggregate_channels(recs) + else: + recording = sa_rec + + return sorting, recording + + +def _sort_spikes(rec, sa_rec, output, ks_image_path, ks4_params): + """ + Sort spikes with kilosort 4. + + params + === + rec: spikeinterface recording object. + + sa_rec: spikeinterface recording object for creating sorting analyser. + + output: path object, directory of output. + + ks_image_path: path object, directory of local kilosort 4 singularity image. + + ks4_params: dict, parameters for kilosort 4. + + return + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. + """ + logging.info("\n> Sorting spikes.") + + # run sorter + sorting = ss.run_sorter( + sorter_name="kilosort4", + recording=rec, + folder=output, + singularity_image=ks_image_path, + remove_existing_folder=True, + verbose=True, + **ks4_params, + ) + + # NOTE: may 20 2025 + # build sa with non-whitened preprocessed rec gives amp between 0-250uV, + # which makes sense, and quality metric amp_median is comparable across + # recordings + if not sa_rec: + # load ks preprocessed recording for # sorting analyser + ks_preprocessed = se.read_binary( + file_paths=output/"sorter_output/temp_wh.dat", + sampling_frequency=rec.sampling_frequency, + dtype=np.int16, + num_channels=rec.get_num_channels(), + is_filtered=True, + ) + # attach probe # to ks4 preprocessed recording, from the raw + recording = ks_preprocessed.set_probe(rec.get_probe()) + # set properties to make sure sorting & sorting sa have all + # probe # properties to form correct rec_attributes, esp + recording._properties = rec._properties + + # >>> annotations >>> + annotations = rec.get_annotation_keys() + annotations.remove("is_filtered") + for ann in annotations: + recording.set_annotation( + annotation_key=ann, + value=rec.get_annotation(ann), + overwrite=True, + ) + # <<< annotations <<< + else: + recording = sa_rec + + return sorting, recording + + +def _curate_sorting(sorting, recording, output): + """ + Curate spike sorting results, and export to disk. + + params + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. + + output: path object, directory of output. + + return + === + sa: spikeinterface sorting analyser. + + curated_sa: curated spikeinterface sorting analyser. + """ + logging.info("\n> Curating sorting.") + + # curate sorter output + # remove spikes exceeding recording number of samples + sorting = sc.remove_excess_spikes(sorting, recording) + # remove duplicate spikes + sorting = sc.remove_duplicated_spikes( + sorting, + censored_period_ms=0.3, + method="keep_first_iterative", + ) + # remove redundant units created by ks + sorting = sc.remove_redundant_units( + sorting, + duplicate_threshold=0.9, # default is 0.8 + align=False, + remove_strategy="max_spikes", + ) + + # create sorting analyser + sa = si.create_sorting_analyzer( + sorting=sorting, + recording=recording, + sparse=True, + format="zarr", + folder=output/"sa.zarr", + overwrite=True, + ) + + # calculate all extensions BEFORE further steps + # list required extensions for redundant units removal and quality + # metrics + required_extensions = [ + "random_spikes", + "waveforms", + "templates", + "noise_levels", + "unit_locations", + "template_similarity", + "spike_amplitudes", + "correlograms", + "principal_components", # for phy + ] + sa.compute(required_extensions, save=True) + + # make sure to have group id for each unit + if not "group" in sa.sorting.get_property_keys(): + # get shank id, i.e., group + group = sa.recording.get_channel_groups() + # get max peak channel for each unit + max_chan = si.get_template_extremum_channel(sa).values() + # get group id for each unit + try: + unit_group = group[list(max_chan)] + except IndexError: + unit_group = group[sa.channel_ids_to_indices(max_chan)] + # set unit group as a property for sorting + sa.sorting.set_property( + key="group", + values=unit_group, + ) + else: + # get max peak channel for each unit + max_chan = si.get_template_extremum_channel(sa).values() + + # calculate quality metrics + qms = sqm.compute_quality_metrics(sa) + + # >>> get depth of units on each shank >>> + # get probe geometry coordinates + coords = sa.get_channel_locations() + # get coordinates of max channel of each unit on probe, column 0 is + # x-axis, column 1 is y-axis/depth, 0 at bottom-left channel. + max_chan_idx = sa.channel_ids_to_indices(max_chan) + max_chan_coords = coords[max_chan_idx] + # set coordinates of max channel of each unit as a property of sorting + sa.sorting.set_property( + key="max_chan_coords", + values=max_chan_coords, + ) + # <<< get depth of units on each shank <<< + + # remove bad units + #rule = "sliding_rp_violation <= 0.1 & amplitude_median <= -40\ + # & amplitude_cutoff < 0.05 & sd_ratio < 1.5 & presence_ratio > 0.9\ + # & snr > 1.1 & rp_contamination < 0.2 & firing_rate > 0.1" + # use the ibl methods, but amplitude_cutoff rather than noise_cutoff + qms_rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -40\ + & presence_ratio > 0.9" + good_qms = qms.query(qms_rule) + logging.info( + "> quality metrics check removed " + f"{np.setdiff1d(sa.unit_ids, good_qms.index.values)}." + ) + # TODO nov 26 2024 + # wait till noise cutoff implemented and include that. + # also see why sliding rp violation gives loads nan. + + # calculate template metrics + tms = spost.compute_template_metrics( + sa, + include_multi_channel_metrics=True, + ) + # remove noise based on waveform + tms_rule = "num_positive_peaks <= 2 & num_negative_peaks == 1 &\ + exp_decay > 0.01 & exp_decay < 0.1" # bombcell + #peak_to_valley > 0.00018 &\ + good_tms = tms.query(tms_rule) + logging.info( + "> Template metrics check removed " + f"{np.setdiff1d(sa.unit_ids, good_tms.index.values)}." + ) + + # get good units that passed quality metrics & template metrics + good_units = np.intersect1d(good_qms.index.values, good_tms.index.values) + good_unit_mask = np.isin(sa.unit_ids, good_units) + + # get template of each unit on its max channel + templates = sa.load_extension("templates").get_data() + unit_idx = sa.sorting.ids_to_indices(good_units) + max_chan_templates = templates[unit_idx, :, max_chan_idx[good_unit_mask]] + + # filter non somatic units by waveform analysis + soma_mask = filter_non_somatics( + sa.unit_ids, + max_chan_templates, + sa.sampling_frequency, + ) + soma_units = good_units[soma_mask] + + # get unit ids + curated_unit_ids = np.intersect1d(good_units, soma_units) + # select curated + curated_sorting = sa.sorting.select_units(curated_unit_ids) + curated_sa = sa.select_units(curated_unit_ids) + # reattach curated sorting to curated_sa to keep sorting properties + curated_sa.sorting = curated_sorting + + return sa, curated_sa + + +def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir, + to_phy=False): + """ + Export sorting analyser to disk. + + params + === + sa: spikeinterface sorting analyser. + + curated_sa_dir: path object, directory to save curated sorting analyser. + + output: path object, directory of output. + + return + === + None + """ + logging.info("\n> Exporting sorting results.") + + # export pre curation report + sexp.export_report( + sorting_analyzer=sa, + output_folder=output/"report", + ) + + # export curated report + sexp.export_report( + sorting_analyzer=curated_sa, + output_folder=output/"curated_report", + ) + + # save sa to disk + curated_sa.save_as( + format="zarr", + folder=curated_sa_dir, + ) + + if to_phy: + export_sa_to_phy(output, sa) + + return None + + +def export_sa_to_phy(path, sa): + # export to phy for additional manual curation if needed + sexp.export_to_phy( + sorting_analyzer=sa, + output_folder=path/"phy", + copy_binary=False, + ) + + return None + + +def _permute_spikes_n_convolve_fr(array, sigma, sample_rate): + """ + Randomly permute spike boolean across time. + + params + === + array: 2D np array, time points x units. + + sigma: int/float, time in millisecond of sigma of gaussian kernel for firing + rate convolution. + + sample_rate: float/int, sampling rate of signal. + + return + === + random_spiked: shuffled spike boolean for each unit. + + random_fr: convolved firing rate from shuffled spike boolean for each unit. + """ + # initiate random number generator every time to avoid same results from the + # same seeding + rng = np.random.default_rng() + # permutate columns + random_spiked = rng.permuted(array, axis=0) + # convolve into firing rate + random_fr = signal.convolve_spike_trains( + times=random_spiked, + sigma=sigma, + sample_rate=sample_rate, + ) + + return random_spiked, random_fr + + +def _worker_write_repeat( + i, zarr_path, sigma, sample_rate, shm_name, shape, dtype_str, +): + # attach to shared memory + shm = shared_memory.SharedMemory(name=shm_name) + try: + spiked = np.ndarray( + shape, + dtype=np.dtype(dtype_str), + buffer=shm.buf, + ) + # avoid accidental in-place mutation + spiked.setflags(write=False) + # get permuted data + c_spiked, c_fr = _permute_spikes_n_convolve_fr(spiked, sigma, sample_rate) + + # child process re-opens the store to avoid pickling big arrays + store = zarr.DirectoryStore(zarr_path) + root = zarr.open_group(store=store, mode="a") + + # Write the i-th slice along last axis + root["chance_spiked"][..., i] = c_spiked + root["chance_fr"][..., i] = c_fr + del c_spiked, c_fr + gc.collect() + + logging.info(f"\nRepeat {i} finished.") + finally: + shm.close() + + return None + + +def save_spike_chance_zarr( + zarr_path, + spiked: np.ndarray, + sigma: float, + sample_rate: float, + repeats: int = 100, + positions=None, + meta: dict | None = None, +): + """ + Create a Zarr store at `zarr_path` with datasets: + - spiked: base spiked array (read-only reference) + - chance_spiked: base_shape + (repeats,), int16 + - chance_fr: base_shape + (repeats,), float32 + - positions: optional small array (or vector), stored if provided + Then fill each repeat slice in parallel processes. + + This function is idempotent: if the target datasets exist and match shape, it skips creation. + """ + n_workers = 2 ** (mp.cpu_count().bit_length() - 2) + + zarr_path = Path(zarr_path) + zarr_path.parent.mkdir(parents=True, exist_ok=True) + + # Convert to array ONCE for shared memory + if isinstance(spiked, pd.DataFrame): + spike_arr = np.ascontiguousarray(spiked.to_numpy()) + else: + spike_arr = np.ascontiguousarray(np.asarray(spiked)) + + base_shape = spiked.shape + d_shape = base_shape + (repeats,) + + chunks = tuple(min(s, BIG_CHUNKS) for s in base_shape) + (1,) + + store = zarr.DirectoryStore(str(zarr_path)) + root = zarr.group(store=store, overwrite=not zarr_path.exists()) + + # Metadata + root.attrs["kind"] = "spike_chance" + root.attrs["sigma"] = float(sigma) + root.attrs["sample_rate"] = float(sample_rate) + root.attrs["repeats"] = int(repeats) + if meta: + for k, v in meta.items(): + root.attrs[k] = v + + logging.info(f"\n> Creating zarr dataset.") + # Base source so workers can read it without pickling + if "spiked" in root: + del root["spiked"] + + if isinstance(spiked, pd.DataFrame): + _df_to_zarr_via_xarray( + df=spiked, + store=store, + group_name="spiked", + compressor=compressor, + mode="w", + ) + else: + root.create_dataset( + "spiked", + data=spiked, + chunks=chunks[:-1], + compressor=compressor, + ) + del spiked + gc.collect() + + # Outputs + if "chance_spiked" in root\ + and tuple(root["chance_spiked"].shape) != d_shape: + del root["chance_spiked"] + if "chance_fr" in root and tuple(root["chance_fr"].shape) != d_shape: + del root["chance_fr"] + + if "chance_spiked" not in root: + root.create_dataset( + "chance_spiked", + shape=d_shape, + dtype="bool", + chunks=chunks, + compressor=compressor, + ) + if "chance_fr" not in root: + root.create_dataset( + "chance_fr", + shape=d_shape, + dtype="float32", + chunks=chunks, + compressor=compressor, + ) + + # save positions + if positions is not None: + if "positions" in root: + del root["positions"] + + if isinstance(positions, pd.DataFrame): + _df_to_zarr_via_xarray( + df=positions, + store=store, + group_name="positions", + compressor=compressor, + mode="w", + ) + else: + root.create_dataset( + "positions", + data=positions, + chunks=True, + compressor=compressor, + ) + + del positions + gc.collect() + + # Create shared memory once + shm = shared_memory.SharedMemory( + create=True, + size=spike_arr.nbytes, + ) + try: + shm_arr = np.ndarray( + base_shape, + dtype=spike_arr.dtype, + buffer=shm.buf, + ) + shm_arr[...] = spike_arr + + logging.info(f"\n> Starting process pool.") + # Pass only the metadata needed to reconstruct the view + dtype_str = spike_arr.dtype.str # portable dtype spec + + # Parallel fill: each worker writes a distinct final-axis slice + with ProcessPoolExecutor(max_workers=n_workers) as ex: + futures = [ + ex.submit( + _worker_write_repeat, + i, + str(zarr_path), + sigma, + sample_rate, + shm.name, + base_shape, + dtype_str, + ) for i in range(repeats) + ] + for f in as_completed(futures): + f.result() # raise on error + finally: + shm.close() + shm.unlink() + + return None + + +def bin_spike_chance(chance_data, sample_rate, time_bin, pos_bin, arr_path): + # extract data from chance + chance_spiked = chance_data["chance_spiked"] + chance_fr = chance_data["chance_fr"] + REPEATS = chance_spiked.shape[-1] + + # get index and columns to reconstruct df + spiked = chance_data["spiked"].dropna(axis=0, how="all") + idx = spiked.index + cols = spiked.columns + trial_ids = spiked.index.get_level_values("trial").unique() + # get positions + positions = chance_data["positions"].dropna(axis=1, how="all") + + count_arrs = {} + fr_arrs = {} + count_dfs = {} + fr_dfs = {} + temp_spiked = {} + temp_fr = {} + for repeat in range(REPEATS): + r_fr = pd.DataFrame( + chance_fr[:, :, repeat], + index=idx, + columns=cols, + ) + r_spiked = pd.DataFrame( + chance_spiked[:, :, repeat], + index=idx, + columns=cols, + ) + + temp_spiked[repeat] = {} + temp_fr[repeat] = {} + for trial in trial_ids: + trial_pos = positions.xs(trial, level="trial", axis=1).dropna() + counts = r_spiked.xs(trial, level="trial", axis=0) + fr = r_fr.xs(trial, level="trial", axis=0) + + # bin fr + temp_fr[repeat][trial] = bin_vr_trial( + data=fr, + positions=trial_pos, + sample_rate=sample_rate, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="mean", # fr + ) + # bin spiked + temp_spiked[repeat][trial] = bin_vr_trial( + data=counts, + positions=trial_pos, + sample_rate=sample_rate, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="sum", # spike count + ) + + fr_dfs[repeat] = pd.concat( + temp_fr[repeat], + axis=0, + ) + count_dfs[repeat] = pd.concat( + temp_spiked[repeat], + axis=0, + ) + + # np array for andrew + fr_arrs[repeat] = reindex_by_longest( + dfs=temp_fr[repeat], + return_format="array", + ) + count_arrs[repeat] = reindex_by_longest( + dfs=temp_spiked[repeat], + return_format="array", + ) + + # save np array, for andrew + arr_count_output = np.stack( + list(count_arrs.values()), + axis=-1, + dtype=np.float32, + ) + arr_fr_output = np.stack( + list(fr_arrs.values()), + axis=-1, + dtype=np.float32, + ) + arrs = { + "count": arr_fr_output[:, :-2, ...], + "fr": arr_count_output[:, :-2, ...], + "pos": arr_fr_output[:, -2:, ...], + } + np.savez_compressed(arr_path, **arrs) + + # output df + df_fr = pd.concat( + fr_dfs, + axis=0, + names=["repeat", "trial", "bin_time"], + ).iloc[:, :-2] + temp = pd.concat( + count_dfs, + axis=0, + names=["repeat", "trial", "bin_time"], + ) + df_spiked = temp.iloc[:, :-2] + df_pos = temp.iloc[:, -2:] + + return {"spiked": df_spiked, "fr": df_fr, "positions": df_pos} + + +def bin_vr_trial(data, positions, sample_rate, time_bin, pos_bin, + bin_method="mean"): + """ + Bin virtual reality trials by given temporal bin and positional bin. + + params + === + data: pandas dataframe, neural data needed binning. + + positions: pandas dataframe, position of current trial. + + time_bin: str, temporal bin for neural data. + + pos_bin: int, positional bin for positions. + + bin_method: str, method to concatenate data within each temporal bin. + "mean": taking the mean of all frames. + "sum": taking sum of all frames. + """ + data = data.copy() + positions = positions.copy() + + # convert index to datetime index for resampling + isi = (1 / sample_rate) * 1000 + data.index = pd.to_timedelta( + arg=data.index * isi, + unit="ms", + ) + + # set position index too + positions.index = data.index + + # resample to ms bin, and get position mean + mean_pos = positions.resample(time_bin).mean() + + if bin_method == "sum": + # resample to Xms bin, and get sum + bin_data = data.resample(time_bin).sum() + elif bin_method == "mean": + # resample to Xms bin, and get mean + bin_data = data.resample(time_bin).mean() + + # add position here to bin together + bin_data['positions'] = mean_pos.values + # add bin positions + bin_pos = mean_pos // pos_bin + 1 + bin_data['bin_pos'] = bin_pos.values + + # use numeric index + bin_data.reset_index(inplace=True, drop=True) + + return bin_data + +def correct_group_id(rec): + # check probe type + ''' + npx 1.0: 0 + npx 2.0 alpha: 24 + npx 2.0 commercial: 2013 + ''' + probe_type = int(rec.get_annotation("probes_info")[0]["probe_type"]) + # double check it is multishank probe + assert probe_type > 0 + + # get channel x locations + shank_x_locs = { + 0: [0, 32], + 1: [250, 282], + 2: [500, 532], + 3: [750, 782], + } + + # get group ids + group_ids = rec.get_channel_groups() + + x_locs = rec.get_channel_locations()[:, 0] + for shank_id, shank_x in shank_x_locs.items(): + # map bool channel x locations + shank_bool = np.isin(x_locs, shank_x) + if np.any(shank_bool) == False: + logging.info( + f"\n> Recording does not have shank {shank_id}, continue." + ) + continue + group_ids[shank_bool] = shank_id + + logging.info( + "\n> Not all shanks used in multishank probe, change group ids into " + f"{np.unique(group_ids)}." + ) + + return group_ids + + +def get_vr_positional_data(trial_data): + """ + Get positional firing rate and spike count for VR behaviour. + + params + === + trial_data: pandas df, output from align_trials. + + return + === + dict, positional firing rate, positional spike count, positional occupancy, + data in 1cm resolution. + """ + # NOTE: take occupancy from spike count since in we might need to + # interpolate fr for binned data + pos_fc, occupancy = _get_vr_positional_neural_data( + positions=trial_data["positions"], + data_type="spiked", + data=trial_data["spiked"], + ) + pos_fr, _ = _get_vr_positional_neural_data( + positions=trial_data["positions"], + data_type="spike_rate", + data=trial_data["fr"], + ) + + return {"pos_fc": pos_fc, "pos_fr": pos_fr, "occupancy": occupancy} + + +def _get_vr_positional_neural_data(positions, data_type, data): + """ + Get positional neural data for VR behaviour. + + params + === + positions: pandas df, vr positions of all trials. + shape: time x trials. + + data_type: str, type of neural data. + "spike_rate": firing rate of each unit in each trial. + "spiked": spike boolean of each unit in each trial. + + data: pandas df, aligned trial firing rate or spike boolean. + shape: time x (unit x trial) + levels: unit, trial + + return + === + pos_data: pandas df, positional neural data. + shape: position x (num of starting positions x unit x trial) + levels: start, unit, trial + + occupancy: pandas df, count of each position. + shape: position x trial + """ + from pandas.api.types import is_integer_dtype + from vision_in_darkness.constants import SPATIAL_SAMPLE_RATE + + if "bin" in positions.index.name: + logging.info(f"\n> Getting binned positional {data_type}...") + # create position indices for binned data + indices_range = [ + positions.min().min(), + positions.max().max()+1, + ] + else: + logging.info(f"\n> Getting positional {data_type}...") + # get constants from vd + from vision_in_darkness.constants import TUNNEL_RESET + # create position indices + indices_range = [np.floor(positions.min().min()), TUNNEL_RESET+1] + + # get trial ids + trial_ids = positions.columns.get_level_values("trial") + + # create position indices + indices = np.arange(*indices_range, SPATIAL_SAMPLE_RATE).astype(int) + # create occupancy array for trials + occupancy = pd.DataFrame( + data=np.full((len(indices), positions.shape[1]), np.nan), + index=indices, + columns=trial_ids, + ) + + pos_data = {} + for t, trial in enumerate(trial_ids): + # get trial position + trial_pos = positions.xs(trial, level="trial", axis=1).dropna() + + # convert to int if float + if not trial_pos.dtypes.map(is_integer_dtype).all(): + # floor position and set to int + trial_pos = trial_pos.apply(lambda x: np.floor(x)).astype(int) + # exclude positions after tunnel reset + trial_pos = trial_pos[trial_pos <= indices[-1]] + + # get firing rates for current trial of all units + try: + trial_data = data.xs( + key=trial, + axis=1, + level="trial", + ).dropna(how="all").copy() + except TypeError: + # chance data has trial and time on index, not column + trial_data = data.T.xs( + key=trial, + axis=1, + level="trial", + ).T.dropna(how="all").copy() + + # get all indices before post reset + no_post_reset = trial_data.index.intersection(trial_pos.index) + # remove post reset rows + trial_data = trial_data.loc[no_post_reset] + trial_pos = trial_pos.loc[no_post_reset] + + # put trial positions in trial data df + trial_data["position"] = trial_pos.values + + if data_type == "spike_rate": + # group values by position and get mean data + how = "mean" + elif data_type == "spiked": + # group values by position and get sum data + how = "sum" + grouped_data = math_utils.group_and_aggregate( + trial_data, + "position", + how, + ) + + # reindex into full tunnel length + reidxed = grouped_data.reindex(indices) + + # check for missing values in binned data + if ("bin" in positions.index.name) and (data_type == "spike_rate"): + # remove alll nan before data actually starts + start_idx = grouped_data.index[0] + chunk_data = reidxed.loc[start_idx:, :] + nan_check = chunk_data.isna().any().any() + if nan_check: + # interpolate missing fr + logging.info(f"\n> trial {trial} has missing values, " + "do linear interpolation.") + reidxed.loc[start_idx:, :] = chunk_data.interpolate( + method="linear", + axis=0, + ) + + # save to dict + pos_data[trial] = reidxed + + # get trial occupancy + pos_count = trial_data.groupby("position").size() + occupancy.loc[pos_count.index.values, trial] = pos_count.values + + # concatenate dfs + pos_data = pd.concat(pos_data, axis=1, names=["trial", "unit"]) + + # add another level of starting position + # group trials by their starting index + trial_level = pos_data.columns.get_level_values("trial") + unit_level = pos_data.columns.get_level_values("unit") + # map start level + starts = positions.columns.get_level_values("start").values + start_series = pd.Series( + data=starts, + index=trial_ids, + name="start", + ) + start_level = trial_level.map(start_series) + + # define new columns + new_cols = pd.MultiIndex.from_arrays( + [start_level, unit_level, trial_level], + names=["start", "unit", "trial"], + ) + pos_data.columns = new_cols + + # sort by unit, starting position, and then trial + pos_data = pos_data.sort_index( + axis=1, + level=["unit", "start", "trial"], + ascending=[True, False, True], + ).dropna(how="all") + + occupancy = occupancy.dropna(how="all") + # remove negative position values + if occupancy.index.min() < 0: + occupancy = occupancy.loc[0:, :] + pos_data = pos_data.loc[0:, :] + + return pos_data, occupancy + + +def get_spatial_psd(pos_fr): + def _compute_psd(col): + x = col.dropna().values.squeeze() + f, psd = math_utils.estimate_power_spectrum(x, use_welch=False) + # remove 0 to avoid infinity + f = f[1:] + psd = psd[1:] + ser = pd.Series(psd, index=f, name=col.name) + ser.index.name = "frequency" + + return ser + + psd = pos_fr.apply(_compute_psd, axis=0) + + return psd + + +def _psd_chance_worker(r, sample_rate, positions, paths, cut_start, cut_end): + """ + Worker that computes one set of psd. + + params + === + i: index of current repeat. + + return + === + """ + logging.info(f"\nProcessing repeat {r}...") + chance_data, idx, cols = get_spike_chance( + sample_rate=sample_rate, + positions=positions, + **paths, + ) + + pos_fr, _ = _get_vr_positional_neural_data( + positions=positions, + data_type="spike_rate", + data=pd.DataFrame(chance_data[..., r], index=idx, columns=cols), + ) + del chance_data, positions + + psd = {} + starts = pos_fr.columns.get_level_values("start").unique() + + for start in starts: + start_df = pos_fr.xs( + start, + level="start", + axis=1, + ).dropna(how="all") + + cropped = start_df.loc[start+cut_start:cut_end, :] + psd[start] = get_spatial_psd(cropped) + + del cropped, start_df + + psd_df = pd.concat( + psd, + names=["start", "frequency"], + ) + + logging.info(f"\nRepeat {r} finished.") + + return psd_df + + +def save_chance_psd(sample_rate, positions, paths):#chance_data, idx, cols): + """ + Implementation of saving chance level spike data. + """ + #import concurrent.futures + from vision_in_darkness.constants import PRE_DARK_LEN, landmarks + + # Set up the process pool to run the worker in parallel. + # Submit jobs for each repeat. + futures = [] + with ProcessPoolExecutor() as executor: + for r in range(REPEATS): + future = executor.submit( + _psd_chance_worker, + r, + sample_rate, + positions, + paths, + PRE_DARK_LEN, + landmarks[-1] - 1 + ) + futures.append(future) + # collect and concat + results = [f.result() for f in as_completed(futures)] + + psds = pd.concat( + results, + axis=1, + keys=range(REPEATS), + names=["repeat", "unit", "trial"], + ) + + return psds + + +def notch_freq(rec, freq, bw=4.0): + """ + Notch a frequency with narrow bandwidth. + + params + === + rec: si recording object. + + freq: float or int, the target frequency in Hz of the notch filter. + + bw: float or int, bandwidth (Hz) of notch filter. + Default: 4.0Hz. + + return + === + notched: spikeinterface recording object. + """ + notched = spre.notch_filter( + rec, + freq=freq, + q=freq/bw, # quality factor + ) + + return notched + + +# >>> landmark responsive helpers >>> +def to_df(mean, std, zone): + out = pd.DataFrame({"mean": mean, "std": std}).reset_index() + + # Keep only start and trial; unit is constant + out = out.rename( + columns={"level_0": "start", "level_1": "unit", "level_2": "trial"} + ) + out["zone"] = zone + + # map data type + out["start"] = out["start"].astype(str) + out["unit"] = out["unit"].astype(str) + out["trial"] = out["trial"].astype(str) + + return out + + +# Build linear contrasts for any model that uses patsy coding +def compute_contrast(fit, newdf_a, newdf_b=None): + """ + Returns estimate and SE for: + L'beta where L = X(newdf_a) - X(newdf_b) if newdf_b is provided, + else L = X(newdf_a) + fit: statsmodels results with fixed effects (OLS or MixedLM) + newdf_a/newdf_b: small DataFrames with columns used in the formula (start, + zone) + """ + # For MixedLM, use fixed-effects params/cov + if hasattr(fit, "fe_params"): + fe_params = fit.fe_params + cov = fit.cov_params().loc[fe_params.index, fe_params.index] + cols = fe_params.index + else: + # OLS + fe_params = fit.params + cov = fit.cov_params() + cols = fit.params.index + + di = fit.model.data.design_info + + Xa = build_design_matrices([di], newdf_a)[0] + Xa = np.asarray(Xa) # column order matches the fit + if Xa.shape[1] != len(cols): + # rare: ensure columns align if needed + raise ValueError( + "Design column mismatch; " + "ensure newdf has the same factors and levels as the fit." + ) + + if newdf_b is not None: + Xb = build_design_matrices([di], newdf_b)[0] + Xb = np.asarray(Xb) + L = (Xa - Xb).ravel() + else: + L = Xa.ravel() + + est = float(L @ fe_params.values) + se = float(np.sqrt(L @ cov.values @ L)) + + return est, se + + +# >>> single unit mixed model +def fit_per_unit_ols(df, formula, unit_id): + """ + Step 1 + Fit mean fr of pre-wall, landmark, and post-wall from each trial, each unit + to GLM with cluster-robust SE. + """ + d = df[df["unit"] == str(unit_id)].copy() + if d.empty: + raise ValueError(f"No data for unit {unit_id}") + # OLS with cluster-robust SE by trial + fit = smf.ols(formula, data=d).fit( + cov_type="cluster", + cov_kwds={"groups": d["trial"]}, + ) + return fit + + +def test_diff_any(fit, starts, use_f=True): + """ + Step 2 + Use Wald test on linear contrasts to test if jointly, all these contrasts + are 0. + i.e., this test if there are any difference among the mean fr comparisons. + if wald p < alpha, then there is a significant difference, we do post-hoc to + see where the difference come from; + if wald p > alpha, the unit does not have different fr between landmark & + pre-wall, and landmark & post-wall, or pre-wall & post-wall. + """ + Ls = [] + for s in starts: + for ref in ["pre_wall", "post_wall"]: + row = _L_row( + fit=fit, + start_label=s, + a_zone="landmark", + b_zone=ref, + ) + Ls.append(row) + + R = np.vstack(Ls) + w = fit.wald_test(R, scalar=True, use_f=use_f) + results = { + "stat": float(w.statistic), + "p": float(w.pvalue), + "df_num": int(getattr(w, "df_num", R.shape[0])), + "df_denom": float(getattr(w, "df_denom", np.nan)), + "k": R.shape[0], + } + + return results["p"] + + +def _L_row(fit, start_label, a_zone, b_zone): + """ + Build linear contrasts of zones for a given starting position. + """ + di = fit.model.data.design_info + Xa = np.asarray( + build_design_matrices( + [di], + pd.DataFrame({"start":[start_label], "zone":[a_zone]}), + )[0] + ).ravel() + Xb = np.asarray( + build_design_matrices( + [di], + pd.DataFrame({"start":[start_label], "zone":[b_zone]}) + )[0] + ).ravel() + + return Xa - Xb + + +def family_comparison(fit, starts, compare_to="pre_wall", use_f=True): + """ + Step 3 + Family level mean comparison, i.e., compare pre or post wall with landmark. + """ + R = np.vstack( + [_L_row(fit, s, "landmark", compare_to) for s in starts] + ) + w = fit.wald_test(R, scalar=True, use_f=use_f) + + results = { + "family": f"LM-{ 'Pre' if compare_to=='pre_wall' else 'Post' }", + "stat": float(w.statistic), + "p": float(w.pvalue), + "df_num": int(getattr(w, "df_num", R.shape[0])), + "df_denom": float(getattr(w, "df_denom", np.nan)), + "n_starts": len(starts), + "R": R + } + return results["p"] + + +def start_contrasts_ols(fit, starts, use_normal=True): + """ + Step 4 + Post-hoc test to see where the difference in contrast come from, i.e., get + the linear contrast for each starting positions. + """ + params = fit.params if not hasattr(fit, "fe_params")\ + else fit.fe_params + cov = fit.cov_params() if not hasattr(fit, "fe_params")\ + else fit.cov_params().loc[params.index, params.index] + + rows = [] + for s in sorted(starts, key=str): + df_pre = pd.DataFrame({"start":[s], "zone":["pre_wall"]}) + df_lm = pd.DataFrame({"start":[s], "zone":["landmark"]}) + df_post = pd.DataFrame({"start":[s], "zone":["post_wall"]}) + + for label, A, B in [("lm-pre", df_lm, df_pre), + ("lm-post", df_lm, df_post)]: + est, se = compute_contrast(fit, A, B) + if not np.isfinite(se) or se <= 0: + stat = p = np.nan + else: + stat = est / se + p = float( + 2 * (stats.norm.sf(abs(stat)) if use_normal + else stats.t.sf(abs(stat), df=fit.df_resid)) + ) + col_names = ["start", "contrast", "coef", "SE", "stat", "p"] + rows.append( + dict(zip(col_names, [s, label, est, se, stat, p])) + ) + + out = pd.DataFrame(rows) + if not out.empty and out["p"].notna().any(): + # get Holm-adjusted p value to correct for multiple comparison + out["p_holm"] = multipletests( + out["p"], + alpha=ALPHA, + method="holm", + )[1] + + # rename 'stat' to 'z' or 't' + stat_label = "z" if use_normal else "t" + out = out.rename(columns={"stat": stat_label}) + + return out + + +def test_start_x_zone_interaction_ols(fit): + # Wald test: all interaction terms = 0 + ix = [i for i, zone in enumerate(fit.params.index) if ":" in zone] + if not ix: + return np.nan + R = np.zeros((len(ix), len(fit.params))) + for r, i in enumerate(ix): + R[r, i] = 1.0 + w = fit.wald_test(R, scalar=True, use_f=False) + stat = w.statistic + + return float(w.statistic), float(w.pvalue) +# <<< single unit mixed model +# <<< landmark responsive helpers <<< + + +def get_landmark_responsives(pos_fr, units, ons, offs): + """ + use int8 to encode responsiveness: + 0: not responsive + 1: positively responsive + -1: negatively responsive + """ + units = units.flat() + + # get all positions + positions = pos_fr.index.to_numpy() + # build mask for all positions + position_mask = (positions[:, None] >= ons)\ + & (positions[:, None] < offs) + + # get pre wall and trial mask + pre_wall = pos_fr.loc[position_mask[:, 0], :] + trials_pre_wall = pre_wall.columns.get_level_values( + "trial" + ).unique() + + # get mean & std of walls and landmark + landmark = pos_fr.loc[position_mask[:, 1], :] + assert (landmark.columns.get_level_values("trial").unique() ==\ + trials_pre_wall).all() + post_wall = pos_fr.loc[position_mask[:, 2], :] + + pre_wall_mean = pre_wall.mean(axis=0) + pre_wall_std = pre_wall.std(axis=0) + + landmark_mean = landmark.mean(axis=0) + landmark_std = landmark.std(axis=0) + + post_wall_mean = post_wall.mean(axis=0) + post_wall_std = post_wall.std(axis=0) + + # aggregate + agg = pd.concat([ + to_df(pre_wall_mean, pre_wall_std, "pre_wall"), + to_df(landmark_mean, landmark_std, "landmark"), + to_df(post_wall_mean, post_wall_std, "post_wall"), + ], + ignore_index=True, + ) + agg["zone"] = pd.Categorical( + agg["zone"], + categories=["pre_wall", "landmark", "post_wall"], + ordered=True, + ) + + # get all starting positions + starts = agg.start.unique() + + # create model formula + min_start = min(starts) + simple_model = ( + "mean ~ C(zone, Treatment(reference='pre_wall'))" + ) + full_model = ( + f"""mean + ~ C(start, Treatment(reference={min_start!r})) + * C(zone, Treatment(reference='pre_wall'))""" + ) + + lm_responsive_bool = np.zeros(len(units)).astype(np.int8) + responsives = pd.Series( + lm_responsive_bool, + index=units, + ) + responsives.index.name = "unit" + + lm_contrasts = {} + for unit_id in units: + unit_fit = fit_per_unit_ols( + df=agg, + formula=full_model, + #formula=simple_model, + unit_id=unit_id, + ) + # check contrast at each start + unit_contrasts = start_contrasts_ols( + fit=unit_fit, + starts=starts, + ) + lm_contrasts[unit_id] = unit_contrasts + + if len(starts) < 2: + logging.info( + f"\n> Skip testing this landmark cuz only start {starts[0]} " + "covers it." + ) + else: + # positive responsive + if (unit_contrasts.coef > 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id] = 1 + # negative responsive + if (unit_contrasts.coef < 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id] = -1 + + contrasts = pd.concat( + lm_contrasts, + axis=0, + names=["unit", "index"], + ).droplevel("index") + + return contrasts, responsives + + +def filter_non_somatics(unit_ids, templates, sampling_freq): + # NOTE: no need to worry about multi-positive-peak templates, cuz we already + # threw them out + from scipy.signal import find_peaks + + # True means yes somatic, False mean non somatic + mask = np.zeros(len(templates), dtype=bool) + + # bombcell non somatic criteria + max_repo_peak_to_trough_ratio = 0.8 + + # height ratio to trough + heigh_ratio_to_trough = 0.15 #0.2 + # minimum width of the peak for detection + min_width = 0 + + ## minimum width of depo peak for filtering + #peak_width_ms = 0.1 # 0.07 + #min_depo_peak_width = int(peak_width_ms / 1000 * sampling_freq) + + # NOTE: since our data is high-pass filtered, the very transient peak before + # could be caused by that, its width does not tell us whether it is a + # somatic unit or not. DO NOT RELY ON PRE TROUGH PEAK TO DECIDE SOMATIC! + # so there are two ways to do this: + # 1. set the detection minimum to min_depo_peak_width like in + # spikeinterface, so that we just ignore those transient peaks and consider + # it as high-pass-filter artefact, has no weight in somatic decision; + # 2. using a very low threshold during detection like what we do here, and + # check the height ratio of the peak, if it exceeds our maximum, still + # consider it as non somatic, even if it is very transient. + + # maximum depo peak to repo peak ratio + max_depo_peak_to_repo_peak_ratio = 1.5 + # maximum depo peak to trough ratio + max_depo_peak_to_trough_ratio = 0.5 + + for t, template in enumerate(templates): + # get absolute maximum + template_max = np.max(np.abs(template)) + # minimum prominence of the peak + prominence = heigh_ratio_to_trough * template_max + # get trough index + trough_idx = np.argmin(template) + trough_height = np.abs(template)[trough_idx] + + # get positive peaks + peak_idx, peak_properties = find_peaks( + x=template, + prominence=prominence, + width=min_width, + ) + + assert len(peak_idx) <= 2, "why are there more than 2 peaks?" + + if not len(peak_idx) == 0: + # maximum positive peak is not larger than 80% of trough + if not np.abs(template[peak_idx][-1])\ + / trough_height < max_repo_peak_to_trough_ratio: + print( + f"> {unit_ids[t]} has positive peak larger than 80% trough" + ) + continue + + if len(peak_idx) > 1: + # if both peaks before or after trough, consider non somatic + if (peak_idx > trough_idx).all()\ + or (peak_idx < trough_idx).all(): + print(f"> {unit_ids[t]} both peaks on one side") + continue + + # check compare bases + # if pre depolarisation peak + assert peak_properties["right_bases"][0] == trough_idx + # if post repolarisation peak + assert peak_properties["left_bases"][1] == trough_idx + + peak_heights = template[peak_idx] + + # check depo peak height is + if not peak_heights[0] / trough_height\ + < max_depo_peak_to_trough_ratio: + print(f"> {unit_ids[t]} depo peak is half the size of trough") + continue + + # check height ratio of peaks, make sure the depolarisation peak is + # NOT much bigger than repolarisation peak + if not peak_heights[0] / peak_heights[1]\ + < max_depo_peak_to_repo_peak_ratio: + print( + f"> {unit_ids[t]} depo peak is 1.5 times larger than " + "the repo" + ) + continue + + # yes somatic if no positive peaks or pass all checks + mask[t] = True + + return mask diff --git a/pixels/signal.py b/pixels/signal_utils.py similarity index 55% rename from pixels/signal.py rename to pixels/signal_utils.py index 6451c92..036621f 100644 --- a/pixels/signal.py +++ b/pixels/signal_utils.py @@ -6,16 +6,53 @@ import time from pathlib import Path +import multiprocessing as mp +from joblib import Parallel, delayed + import cv2 import numpy as np import matplotlib.pyplot as plt import pandas as pd import scipy.signal -from scipy.ndimage import gaussian_filter1d +from scipy.signal.windows import gaussian +from scipy.ndimage import gaussian_filter1d, convolve1d from pixels import ioutils, PixelsError +def decimate(array, from_hz, to_hz, ftype="fir"): + """ + Downsample the signal after applying an anti-aliasing filter. + Downsampling factor MUST be an integer, if not, call `resample`. + + Params + === + array : ndarray, Series or similar + The data to be resampled. + + from_hz : int or float + The starting frequency of the data. + + sample_rate : int or float, optional + The resulting sample rate. + + ftype: str or dlit instance + low pass filter type. + Default: fir (finite impulse response) filter + """ + if from_hz % to_hz == 0: + factor = from_hz // to_hz + output = scipy.signal.decimate( + x=array, + q=factor, + ftype=ftype, + ) + else: + output = resample(array, from_hz, to_hz) + + return output + + def resample(array, from_hz, to_hz, poly=True, padtype=None): """ Resample an array from one sampling rate to another. @@ -33,7 +70,7 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): poly : bool, choose the resample function. If True, use scipy.signal.resample_poly; if False, use scipy.signal.resample. - Default is False. + Default is True. lfp downsampling only works if using scipy.signal.resample. Returns @@ -76,6 +113,28 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): chunks = int(np.ceil(size_bytes / 5368709120)) chunk_size = int(np.ceil(cols / chunks)) + # get index & chunk data + #chunk_indices = [(i, min(i + chunk_size, cols)) for i in range(0, cols, chunk_size)] + #chunks_data = [array[:, start:end] for start, end in chunk_indices] + # get number of processes/jobs + #n_processes = mp.cpu_count() - 2 + # initiate a mp pool + #pool = mp.Pool(n_processes) + ## does resample for each chunk + #results = pool.starmap( + # _resample_chunk, + # [(chunk, up, down, poly, padtype) for chunk in chunks_data], + #) + + ## stop adding task to pool + #pool.close() + ## wait till all tasks in pool completed + #pool.join() + #results = Parallel(n_jobs=-1)( + # delayed(_resample_chunk)(chunk, up, down, poly, padtype) for chunk in chunks_data + #) + #print(">> mapped chunk data to pool...") + if chunks > 1: print(f" 0%", end="\r") current = 0 @@ -100,10 +159,26 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): new_data.append(result) current += chunk_size print(f" {100 * current / cols:.1f}%", end="\r") + #new_data = np.concatenate(results, axis=1).squeeze() - return np.concatenate(new_data, axis=1) #.astype(np.int16) + return np.concatenate(new_data, axis=1).squeeze()#.astype(np.int16) + #return new_data +def _resample_chunk(chunk_data, up, down, poly, padtype): + if poly: + result = scipy.signal.resample_poly( + chunk_data, up, down, axis=0, padtype=padtype or 'minimum', + ) + else: + samp_num = int(np.ceil( + chunk_data.shape[0] * (up / down) + )) + result = scipy.signal.resample( + chunk_data, samp_num, axis=0 + ) + return result + def binarise(data): """ This normalises an array to between 0 and 1 and then makes all values below 0.5 @@ -134,13 +209,13 @@ def _binarise_real(data): def find_sync_lag(array1, array2, plot=False): """ Find the lag between two arrays where they have the greatest number of the same - values. This functions assumes that the lag is less than 120,000 points. + values. This functions assumes that the lag is less than 300,000 points. Parameters ---------- array1 : array, Series or similar - The first array. A positive result indicates that this array has leading data - not present in the second array. e.g. if lag == 5 then array2 starts on the 5th + The first array. A positive result indicates that THIS ARRAY HAS LEADING DATA + NOT PRESENT IN THE SECOND ARRAY. e.g. if lag == 5 then array2 starts on the 5th index of array1. array2 : array, Series or similar @@ -165,33 +240,39 @@ def find_sync_lag(array1, array2, plot=False): array1 = array1.squeeze() array2 = array2.squeeze() - sync_p = [] + sync_pos = [] for i in range(length): + # finds how many values are the same in array1 as in array2 till given length matches = np.count_nonzero(array1[i:i + length] == array2[:length]) - sync_p.append(100 * matches / length) - match_p = max(sync_p) - lag_p = sync_p.index(match_p) - - sync_n = [] + # append the percentage of match given length + sync_pos.append(100 * matches / length) + # take the highest percentage during checks as the match + match_pos = max(sync_pos) + # find index where lag started in array1 + lag_pos = sync_pos.index(match_pos) + + sync_neg = [] for i in range(length): + # finds how many values are the same in array2 as in array1 till given length matches = np.count_nonzero(array2[i:i + length] == array1[:length]) - sync_n.append(100 * matches / length) - match_n = max(sync_n) - lag_n = sync_n.index(match_n) - - if match_p > match_n: - lag = lag_p - match = match_p + # append the percentage of match given length + sync_neg.append(100 * matches / length) + match_neg = max(sync_neg) + lag_neg = sync_neg.index(match_neg) + + if match_pos > match_neg: + lag = lag_pos + match = match_pos else: - lag = - lag_n - match = match_n + lag = - lag_neg + match = match_neg if plot: plot = Path(plot) if plot.exists(): plot = plot.with_name(plot.stem + '_' + time.strftime('%y%m%d-%H%M%S') + '.png') fig, axes = plt.subplots(nrows=2, ncols=1) - plot_length = min(length, 5000) + plot_length = min(length, 30000) if lag >= 0: axes[0].plot(array1[lag:lag + plot_length]) axes[1].plot(array2[:plot_length]) @@ -220,6 +301,80 @@ def median_subtraction(data, axis=0): return data - np.median(data, axis=axis, keepdims=True) +def _convolve_worker(shm_kernal, shm_times, sample_rate): + # TODO sep 22 2025: + # CONTINUE HERE! + # attach to shared memory + shm = shared_memory.SharedMemory(name=shm_kernal) + + convolved = convolve1d( + input=times.values, + weights=n_kernel, + output=np.float32, + mode="nearest", + axis=0, + ) * sample_rate # rescale it to second + + return None + + +def convolve_spike_trains(times, sigma=100, size=10, sample_rate=1000): + """ + Convolve spike times data with 1D gaussian kernel to get spike rate. + + Parameters + ------- + times : pandas.DataFrame, time x units + Spike bool of units at each time point of a trial. + Dtype needs to be float, otherwise convolved results will be all 0. + + sigma : float/int, optional + Time in milliseconds of sigma of gaussian kernel to use. + Default: 100 ms. + + size : float/int, optional + Number of sigma for gaussian kernel to cover, i.e., size of the kernel + Default: 10. + + """ + # get kernel size in ms + kernel_size = int(sigma * size) + # get gaussian kernel + kernel = gaussian(kernel_size, std=sigma) + # normalise kernel to ensure that the total area under the Gaussian is 1 + n_kernel = kernel / np.sum(kernel) + + # TODO sept 19 2025: + # implement multiprocessing? + if isinstance(times, pd.DataFrame): + # convolve with gaussian + convolved = convolve1d( + input=times.values, + weights=n_kernel, + output=np.float32, + mode='nearest', + axis=0, + ) * sample_rate # rescale it to second + + output = pd.DataFrame( + convolved, + columns=times.columns, + index=times.index, + ) + + elif isinstance(times, np.ndarray): + # convolve with gaussian + output = convolve1d( + input=times, + weights=n_kernel, + output=np.float32, + mode='nearest', + axis=0, + ) * sample_rate # rescale it to second + + return output + + def convolve(times, duration, sigma=None): """ Create a continuous signal from a set of spike times in milliseconds and convolve @@ -242,10 +397,10 @@ def convolve(times, duration, sigma=None): sigma = 50 # turn into array of 0s and 1s - times_arr = np.zeros((int(duration), len(times.columns))) + times_arr = np.zeros((duration.astype(int), len(times.columns))) for i, unit in enumerate(times): u_times = times[unit] + duration / 2 - u_times = u_times[~np.isnan(u_times)].astype(np.int) + u_times = u_times[~np.isnan(u_times)].astype(int) try: times_arr[u_times, i] = 1 except IndexError: @@ -301,3 +456,38 @@ def motion_index(video, rois): mi = mi / mi.max(axis=0) return mi + + +def freq_notch(x, fs, w0, axis=0, bw=4.0): + """ + Use a notch filter that is a band-stop filter with a narrow bandwidth + between 48 to 52Hz. + It rejects a narrow frequency band and leaves the rest of the spectrum + little changed. + + params + === + x: array like, data to filter. + + fs: float or int, sampling frequency of x. + + w0: float or int, target frequency to notch (Hz). + + axis: int, axis of x to apply filter. + + bw: float or int, bandwidth of notch filter. + + return + === + notched: array like, notch filtered x. + """ + # convert to float + x = x.astype(np.float32, copy=False) + # set quality factor + Q = w0 / bw + # get numerator b & denominator a of IIR filter + b, a = scipy.signal.iirnotch(w0, Q, fs=fs) + # apply digital filter forward and backward + notched = scipy.signal.filtfilt(b, a, x, axis=axis) + + return notched diff --git a/pixels/stream.py b/pixels/stream.py new file mode 100644 index 0000000..473d664 --- /dev/null +++ b/pixels/stream.py @@ -0,0 +1,1353 @@ +# annotations not evaluated at runtime +from __future__ import annotations + +import gc +from shutil import copyfile + +import numpy as np +import pandas as pd + +import spikeinterface as si + +from pixels import ioutils +from pixels import pixels_utils as xut +import pixels.signal_utils as signal +from pixels.configs import * +from pixels.constants import * +from pixels.decorators import cacheable +from pixels.error import PixelsError + +from common_utils import file_utils + +class Stream: + def __init__( + self, + stream_id, + stream_num, + files, + session, + ): + self.stream_id = stream_id + self.stream_num = stream_num + self.probe_id = stream_id[:-3] + self.files = files + + self.session = session + self.behaviour_files = session.files["behaviour"] + self.BEHAVIOUR_SAMPLE_RATE = session.SAMPLE_RATE + self.raw = session.raw + self.interim = session.interim + self.cache = self.interim / "cache/" + self.processed = session.processed + self.histology = session.histology + + self._use_cache = True + + def __repr__(self): + return f"" + + + def load_raw_ap(self, copy=False): + paths = [ + self.session.find_file(path, copy=copy) + for path in self.files["ap_raw"] + ] + self.files["si_rec"] = xut.load_raw(paths, self.stream_id) + + return self.files["si_rec"] + + + def _map_trials(self, label, event, end_event=None): + # get synched pixels stream with vr and action labels + synched_vr, action_labels = self.get_synched_vr() + + # get action and event label file + outcomes = action_labels["outcome"] + events = action_labels["events"] + # get timestamps index of behaviour in self.BEHAVIOUR_SAMPLE_RATE hz, to + # convert it to ms, do timestamps*1000/self.BEHAVIOUR_SAMPLE_RATE + timestamps = action_labels["timestamps"] + + # select frames of wanted trial type + trials = np.flatnonzero(outcomes & label) + # map starts by event + starts = np.flatnonzero(events & event) + # map starts by end event + ends = np.flatnonzero(events & end_event) + + if ("dark" in label.name) and\ + any(name in event.name for name in ["landmark", "wall"]): + # get dark onset and offset edges + dark_on = ((events & event.dark_on) != 0).astype(np.int8) + dark_off = ((events & event.dark_off) != 0).astype(np.int8) + # make in dark boolean + in_dark = np.cumsum(dark_on - dark_off) > 0 + # get starts only when also in dark + starts = np.flatnonzero(((events & event) != 0) & in_dark) + + # only take starts and ends from selected trials + selected_starts = trials[np.isin(trials, starts)] + selected_ends = trials[np.isin(trials, ends)] + + # make sure trials have both starts and ends, some trials ended before + # the end_event, and some dark trials are not in dark at start event + start_ids = synched_vr.iloc[selected_starts].trial_count.unique() + end_ids = synched_vr.iloc[selected_ends].trial_count.unique() + common_ids = np.intersect1d(start_ids, end_ids) + if len(start_ids) != len(end_ids): + selected_starts = selected_starts[np.isin(start_ids, common_ids)] + selected_ends = selected_ends[np.isin(end_ids, common_ids)] + + # get timestamps + start_t = timestamps[selected_starts] + end_t = timestamps[selected_ends] + + # use original trial ids as trial index + trial_ids = pd.Index(common_ids) + + return trials, events, selected_starts, start_t, end_t, trial_ids + + + #@cacheable(cache_format="zarr") + def _get_vr_positions(self, label, event, end_event): + logging.info( + f"\n> Getting {self.session.name} {self.stream_id} positions." + ) + + # map trials + (trials, events, selected_starts, + start_t, end_t, trial_ids) = self._map_trials( + label, + event, + end_event, + ) + + if selected_starts.size == 0: + logging.info(f"\n> No trials found with label {label} and event " + f"{event.name}, output will be empty.") + return None + + # get synched vr + synched_vr, _ = self.get_synched_vr() + + # get positions of all trials + all_pos = synched_vr.position_in_tunnel + all_pos_val = all_pos.to_numpy() + all_pos_idx = all_pos.index.to_numpy() + + # find start and end position index + trial_start_t = np.searchsorted(all_pos_idx, start_t, side="left") + trial_end_t = np.searchsorted(all_pos_idx, end_t, side="right") + trials_positions = [ + pd.Series(all_pos_val[s:e]) + for s, e in zip(trial_start_t, trial_end_t) + ] + positions = pd.concat(trials_positions, axis=1) + + # map actual starting locations + if not "trial_start" in event.name: + all_start_idx = np.flatnonzero(events & event.trial_start) + trial_start_idx = trials[np.isin(trials, all_start_idx)] + # make sure to only get the included trials' starting positions, + # i.e., the one with start and end event, not all trials in the + # label by aligning trial ids + all_ids = synched_vr.trial_count.iloc[trial_start_idx].values + start_idx = trial_start_idx[np.isin(all_ids, trial_ids)] + else: + start_idx = selected_starts.copy() + + # get start positions + start_pos = all_pos_val[start_idx].astype(int) + # create multiindex with starts + cols_with_starts = pd.MultiIndex.from_arrays( + [start_pos, trial_ids], + names=("start", "trial"), + ) + + # add level with start positions + positions.columns = cols_with_starts + positions = positions.sort_index(axis=1, ascending=[False, True]) + positions.index.name = "time" + + return positions + + + #@cacheable(cache_format="zarr") + def _get_vr_spikes(self, units, label, event, sigma, end_event): + logging.info( + f"\n> Getting {self.session.name} {self.stream_id} spikes." + ) + + # map trials timestamps and index + (_, _, selected_starts, + start_t, end_t, trial_ids) = self._map_trials( + label, + event, + end_event, + ) + + if selected_starts.size == 0: + logging.info(f"\n> No trials found with label {label} and event " + f"{event.name}, output will be empty.") + return None + + # get spike times + spikes = self.get_spike_times(units) + units = units[self.stream_id] + + # pad ends with 1 second extra to remove edge effects from + # convolution + scan_pad = self.BEHAVIOUR_SAMPLE_RATE + scan_starts = start_t - scan_pad + scan_ends = end_t + scan_pad + 1 + scan_durations = scan_ends - scan_starts + + cursor = 0 + raw_rec = self.load_raw_ap() + samples = raw_rec.get_total_samples() + # Account for multiple raw data files + in_SAMPLE_RATE_scale = (samples * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + cursor_duration = (cursor * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + rec_spikes = spikes[ + (cursor_duration <= spikes)\ + & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) + ] - cursor_duration + cursor += samples + + trials_spiked = {} + trials_fr = {} + for i, start in enumerate(selected_starts): + # select spike times of current trial + trial_bool = (rec_spikes >= scan_starts[i])\ + & (rec_spikes <= scan_ends[i]) + trial = rec_spikes[trial_bool] + + # initiate binary spike times array for current trial + # NOTE: dtype must be float otherwise would get all 0 when passing + # gaussian kernel + times = np.zeros((scan_durations[i], len(units)), dtype=np.float32) + # use pixels time as spike index + idx = np.arange(scan_starts[i], scan_ends[i]) + + for j, unit in enumerate(trial): + # get spike time for unit + u_times = trial[unit].values + # drop nan + u_times = u_times[~np.isnan(u_times)] + # round spike times to use it as index + u_spike_idx = np.round(u_times).astype(int) - scan_starts[i] + # make sure it does not exceed scan duration + u_spike_idx = u_spike_idx[ + (u_spike_idx >= 0) & (u_spike_idx < scan_durations[i]) + ] + if u_spike_idx.size: + # set spiked to 1 + times[np.unique(u_spike_idx), j] = 1 + + # make it df, column name being unit id + spiked = pd.DataFrame(times, index=idx, columns=units) + + # convolve spike trains into spike rates + rates = signal.convolve_spike_trains( + times=spiked, + sigma=sigma, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + ) + + # remove 1s padding from the start and end + rates = rates.iloc[scan_pad: -scan_pad] + spiked = spiked.iloc[scan_pad: -scan_pad] + + # reset index to zero at the beginning of the trial + rates.reset_index(inplace=True, drop=True) + trials_fr[trial_ids[i]] = rates + spiked.reset_index(inplace=True, drop=True) + trials_spiked[trial_ids[i]] = spiked + + # get trials vertically stacked spiked + stacked_spiked = pd.concat( + trials_spiked, + axis=0, + ) + stacked_spiked.index.names = ["trial", "time"] + stacked_spiked.columns.names = ["unit"] + + output = {} + # get trials horizontally stacked spiked + spiked = ioutils.reindex_by_longest( + dfs=stacked_spiked, + level="trial", + return_format="dataframe", + ) + fr = ioutils.reindex_by_longest( + dfs=trials_fr, + level="trial", + idx_names=["trial", "time"], + col_names=["unit"], + return_format="dataframe", + ) + + output["spiked"] = spiked + output["fr"] = fr + + return stacked_spiked, output + + + @cacheable#(cache_format="zarr") + def align_trials(self, units, data, label, event, sigma, end_event): + """ + Align pixels data to behaviour trials. + + params + === + units : dictionary of lists of ints + The output from self.select_units, used to only apply this method to a + selection of units. + + data : str, optional + The data type to align. + + label : int + An action label value to specify which trial types are desired. + + event : int + An event type value to specify which event to align the trials to. + + sigma : int, optional + Time in milliseconds of sigma of gaussian kernel to use when + aligning firing rates. + + end_event : int | None + For VR behaviour, when aligning to the whole trial, this param is + the end event to align to. + + return + === + df, output from individual functions according to data type. + """ + + if "spike_trial" in data: + logging.info( + f"\n> Aligning spike times and spike rate of {units} units to " + f"<{label.name}> trials, from {event.name} to {end_event.name}." + ) + return self._get_aligned_trials( + label, event, units=units, sigma=sigma, end_event=end_event, + ) + elif "spike_event" in data: + logging.info( + f"\n> Aligning spike times and spike rate of {units} units to " + f"{event.name} event in <{label.name}> trials." + ) + return self._get_aligned_events( + label, event, units=units, sigma=sigma, + ) + else: + raise NotImplementedError( + "> Other types of alignment are not implemented." + ) + + + def _get_aligned_trials( + self, label, event, units=None, sigma=None, end_event=None, + ): + # get positions + positions = self._get_vr_positions(label, event, end_event) + + # get spikes and firing rate + _, output = self._get_vr_spikes( + units, + label, + event, + sigma, + end_event, + ) + output["positions"] = positions + + return output + + + def _get_aligned_events(self, label, event, units=None, sigma=None): + # TODO oct 17 2025: + # use _get_vr_spikes?? + # get spikes and firing rate + _, output = self._get_vr_spikes( + units, + label, + event, + sigma, + end_event, + ) + + # get synched pixels stream with vr and action labels + synched_vr, action_labels = self.get_synched_vr() + + # get positions of all trials + all_pos = synched_vr.position_in_tunnel + + # get spike times + spikes = self.get_spike_times(units) + # now get array unit ids + units = units[self.stream_id] + + # get action and event label file + outcomes = action_labels["outcome"] + events = action_labels["events"] + # get timestamps index of behaviour in self.BEHAVIOUR_SAMPLE_RATE hz, to + # convert it to ms, do timestamps*1000/self.BEHAVIOUR_SAMPLE_RATE + timestamps = action_labels["timestamps"] + + # select frames of wanted trial type + trials = np.where(np.bitwise_and(outcomes, label))[0] + # map starts by event + starts = np.where(np.bitwise_and(events, event))[0] + + # only take starts from selected trials + selected_starts = trials[np.where(np.isin(trials, starts))[0]] + start_t = timestamps[selected_starts] + + if selected_starts.size == 0: + logging.info(f"\n> No trials found with label {label} and event " + f"{event.name}, output will be empty.") + return None + + # use original trial id as trial index + trial_ids = pd.Index( + synched_vr.iloc[selected_starts].trial_count.unique() + ) + + # TODO aug 1 2025: + # lick happens more than once in a trial, thus i here does not + # correspond to trial index, fit it + # check if event happens more than once in each trial + if start_t.size > trial_ids.size: + trial_counts = synched_vr.loc[start_t, "trial_count"] + + # map actual starting locations + if not "trial_start" in event.name: + all_start_idx = np.where( + np.bitwise_and(events, event.trial_start) + )[0] + start_idx = trials[np.where( + np.isin(trials, all_start_idx) + )[0]] + else: + start_idx = selected_starts.copy() + + start_pos = synched_vr.position_in_tunnel.iloc[ + start_idx + ].values.astype(int) + + # map starting position with trial + start_trial_maps = dict(zip(trial_ids, start_pos)) + + # pad ends with 1 second extra to remove edge effects from convolution, + # during of event is 2s (pre + post) + duration = 1 + pad_duration = 1 + scan_pad = self.BEHAVIOUR_SAMPLE_RATE + one_side_frames = scan_pad * (duration + pad_duration) + scan_starts = start_t - one_side_frames + scan_ends = start_t + one_side_frames + 1 + scan_duration = one_side_frames * 2 + 1 + relative_idx = np.linspace( + -(duration+pad_duration), + (duration+pad_duration), + scan_duration, + ) + + cursor = 0 + raw_rec = self.load_raw_ap() + samples = raw_rec.get_total_samples() + # Account for multiple raw data files + in_SAMPLE_RATE_scale = (samples * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + cursor_duration = (cursor * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + rec_spikes = spikes[ + (cursor_duration <= spikes)\ + & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) + ] - cursor_duration + cursor += samples + + output = {} + trials_fr = {} + trials_spiked = {} + trials_positions = {} + for i, start in enumerate(selected_starts): + assert 0 + # TODO sep 16 2025: + # make sure it does not fail for event that happens more than once, + # e.g., licks + # select spike times of event in current trial + trial_bool = (rec_spikes >= scan_starts[i])\ + & (rec_spikes <= scan_ends[i]) + trial = rec_spikes[trial_bool] + + # initiate binary spike times array for current trial + # NOTE: dtype must be float otherwise would get all 0 when passing + # gaussian kernel + times = np.zeros((scan_duration, len(units))).astype(float) + # use pixels time as spike index + idx = np.arange(scan_starts[i], scan_ends[i]) + # make it df, column name being unit id + spiked = pd.DataFrame(times, index=idx, columns=units) + + for unit in trial: + # get spike time for unit + u_times = trial[unit].values + # drop nan + u_times = u_times[~np.isnan(u_times)] + # round spike times to use it as index + u_spike_idx = np.round(u_times).astype(int) + # make sure it does not exceed scan duration + if (u_spike_idx >= scan_ends[i]).any(): + beyonds = np.where(u_spike_idx >= scan_ends[i])[0] + u_spike_idx[beyonds] = idx[-1] + # make sure no double counted + u_spike_idx = np.unique(u_spike_idx) + + # set spiked to 1 + spiked.loc[u_spike_idx, unit] = 1 + + # set spiked index to relative index + spiked.index = relative_idx + # convolve spike trains into spike rates + rates = signal.convolve_spike_trains( + times=spiked, + sigma=sigma, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + ) + + # remove 1s padding from the start and end + rates = rates.iloc[scan_pad: -scan_pad] + spiked = spiked.iloc[scan_pad: -scan_pad] + + trials_fr[trial_ids[i]] = rates + trials_spiked[trial_ids[i]] = spiked + + # get trials vertically stacked spiked + stacked_spiked = pd.concat( + trials_spiked, + axis=0, + ) + stacked_spiked.index.names = ["trial", "time"] + stacked_spiked.columns.names = ["unit"] + + # get trials horizontally stacked spiked + spiked = ioutils.reindex_by_longest( + dfs=stacked_spiked, + level="trial", + return_format="dataframe", + ) + trial_cols = spiked.columns.get_level_values("trial") + start_cols = trial_cols.map(start_trial_maps) + unit_cols = spiked.columns.get_level_values("unit") + new_cols = pd.MultiIndex.from_arrays( + [unit_cols, start_cols, trial_cols], + names=["unit", "start", "trial"] + ) + spiked.columns = new_cols + spiked = spiked.sort_index( + axis=1, + level=["unit", "start", "trial"], + ascending=[True, False, True], + ) + + fr = ioutils.reindex_by_longest( + dfs=trials_fr, + level="trial", + idx_names=["trial", "time"], + col_names=["unit"], + return_format="dataframe", + ) + fr.columns = new_cols + fr = fr.loc[:, spiked.columns] + + output["spiked"] = spiked + output["fr"] = fr + + return output + + + def get_spike_times(self, units): + units = units[self.stream_id] + + # find sorting analyser, use merged if there is one + merged_sa_path = self.session.find_file( + self.files["merged_sorting_analyser"] + ) + if merged_sa_path: + sa_path = merged_sa_path + else: + sa_path = self.session.find_file(self.files["sorting_analyser"]) + # load sorting analyser + temp_sa = si.load_sorting_analyzer(sa_path) + # select units + sorting = temp_sa.sorting.select_units(units) + sa = temp_sa.select_units(units) + sa.sorting = sorting + + times = {} + # get spike train + for i, unit_id in enumerate(sa.unit_ids): + unit_times = sa.sorting.get_unit_spike_train( + unit_id=unit_id, + return_times=False, + ) + times[unit_id] = pd.Series(unit_times) + + # concatenate units + spike_times = pd.concat( + objs=times, + axis=1, + names="unit", + ) + # get sampling frequency + fs = int(sa.sampling_frequency) + # Convert to time into sample rate index + spike_times /= fs / self.BEHAVIOUR_SAMPLE_RATE + + return spike_times + + + def sync_vr(self, vr_session): + # get action labels & synched vr path + action_labels = self.session.get_action_labels()[self.stream_num] + synched_vr_path = self.session.find_file( + self.behaviour_files["vr_synched"][self.stream_num], + ) + if action_labels and synched_vr_path: + logging.info(f"\n> {self.stream_id} from {self.session.name} is " + "already synched with vr.") + else: + self._sync_vr(vr_session) + + return None + + + def _sync_vr(self, vr_session): + # get synchronised vr path + synched_vr_path = vr_session.cache_dir + "synched/" +\ + vr_session.name + "_vr_synched.h5" + + try: + synched_vr = file_utils.read_hdf5(synched_vr_path) + logging.info("\n> synchronised vr loaded") + except: + # get spike data + spike_data = self.session.find_file( + name=self.files["ap_raw"][self.stream_num], + copy=True, + ) + + # get sync pulses + sync_map = ioutils.read_bin(spike_data, 385, 384) + syncs = signal.binarise(sync_map) + + # >>>> resample pixels sync pulse to sample rate >>>> + # get ap data sampling rate + spike_samp_rate = int(self.session.ap_meta[0]['imSampRate']) + # downsample pixels sync pulse + downsampled = signal.decimate( + array=syncs, + from_hz=spike_samp_rate, + to_hz=self.BEHAVIOUR_SAMPLE_RATE, + ) + # binarise to avoid non integers + pixels_syncs = signal.binarise(downsampled) + # <<<< resample pixels sync pulse to 1kHz <<<< + + # TODO apr 11 2025: + # for 20250723 VDCN09, sync pulses are weird. check number of syncs + # from 145s onwards, see if it matches with vr frames from 100s + # onwards + + # get the rise and fall edges in pixels sync + pixels_edges = np.where(np.diff(pixels_syncs) != 0)[0] + 1 + # double check if the pulse from arduino initiation is also + # included, if so there will be two long pulses before vr frames + first_pulses = np.diff(pixels_edges)[:4] + if (first_pulses > 1000).all(): + logging.info("\n> There are two long pulses before vr frames, " + "remove both.") + remove = 4 + else: + remove = 2 + pixels_vr_edges = pixels_edges[remove:] + # convert value into their index to calculate all timestamps + pixels_idx = np.arange(pixels_syncs.shape[0]) + + synched_vr = vr_session.sync_streams( + self.BEHAVIOUR_SAMPLE_RATE, + pixels_vr_edges, + pixels_idx, + ) + + synched_vr_file = self.behaviour_files["vr_synched"][self.stream_num] + try: + assert self.session.find_file(synched_vr_file) + except: + file_utils.write_hdf5( + self.processed / synched_vr_file, + synched_vr, + ) + + # get action label dir + action_labels_path = self.processed /\ + self.behaviour_files["action_labels"][self.stream_num] + + # extract and save action labels + action_labels = self.session._extract_action_labels( + vr_session, + synched_vr, + ) + labels_dict = action_labels._asdict() + np.savez_compressed( + action_labels_path, + **labels_dict, + ) + logging.info(f"\n> Action labels saved to: {action_labels_path}.") + + if hasattr(self.session, "backup"): + # copy to backup if backup setup + copyfile( + self.processed / synched_vr_file, + self.session.backup / synched_vr_file, + ) + copyfile( + action_labels_path, + self.session.backup / action_labels_path.name, + ) + logging.info( + "\n> Syched vr and action labels copied to: " + f"{self.session.backup}." + ) + + return None + + + def get_synched_vr(self): + """ + Get synchronised vr data and action labels. + """ + action_labels = self.session.get_action_labels()[self.stream_num] + + synched_vr_path = self.session.find_file( + self.behaviour_files["vr_synched"][self.stream_num], + ) + synched_vr = file_utils.read_hdf5(synched_vr_path) + + return synched_vr, action_labels + + + @cacheable + def get_binned_trials( + self, label, event, units=None, sigma=None, end_event=None, + time_bin=None, pos_bin=None + ): + # define output path for binned spike rate + output_path = self.cache/ f"{self.session.name}_{units}_{label.name}_"\ + f"{time_bin}_{pos_bin}cm_{self.stream_id}.npz" + binned = self._bin_aligned_trials( + label=label, + event=event, + units=units, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + output_path=output_path, + ) + + return binned + + + def _bin_aligned_trials( + self, label, event, units, sigma, end_event, time_bin, pos_bin, + output_path, + ): + # get aligned trials + trials = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="spike_trial", # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg + ) + + if trials is None: + logging.info(f"\n> No trials found with label {label.name} and " + f"event {event.name}, output will be empty.") + return None + + logging.info( + f"\n> Binning <{label.name}> trials from {self.stream_id} " + f"in {units}." + ) + + # get fr, spiked, positions + fr = trials["fr"] + spiked = trials["spiked"] + positions = trials["positions"] + + # TODO apr 11 2025: + # bin chance while bin data + #spiked_chance_path = self.processed / stream_files["spiked_shuffled"] + #spiked_chance = ioutils.read_hdf5(spiked_chance_path, "spiked") + #bin_counts_chance[stream_id] = {} + + bin_arr = {} + binned_count = {} + binned_fr = {} + + trial_ids = positions.columns.get_level_values("trial").unique() + for trial in trial_ids: + counts = spiked.xs(trial, level="trial", axis=1).dropna() + rates = fr.xs(trial, level="trial", axis=1).dropna() + trial_pos = positions.xs(trial, level="trial", axis=1).dropna() + + # get bin spike count + binned_count[trial] = xut.bin_vr_trial( + data=counts, + positions=trial_pos, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="sum", + ) + # get bin firing rates + binned_fr[trial] = xut.bin_vr_trial( + data=rates, + positions=trial_pos, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="mean", + ) + + # stack df values into np array + # reshape into trials x units x bins + count_arr = ioutils.reindex_by_longest(binned_count).T + fr_arr = ioutils.reindex_by_longest(binned_fr).T + + # save bin_fr and bin_count, for andrew + # use label as array key name + bin_arr["count"] = count_arr[:, :-2, :] + bin_arr["fr"] = fr_arr[:, :-2, :] + bin_arr["pos"] = count_arr[:, -2:, :] + + np.savez_compressed(output_path, **bin_arr) + logging.info(f"\n> Output saved at {output_path}.") + + # extract binned data in df format + bin_fc, bin_pos = self._extract_binned_data( + binned_count, + positions.columns, + ) + bin_fr, _ = self._extract_binned_data( + binned_fr, + positions.columns, + ) + + # convert it to binned positional data + pos_data = xut.get_vr_positional_data( + { + "positions": bin_pos.bin_pos, + "fr": bin_fr, + "spiked": bin_fc, + }, + ) + + return pos_data + + + def _extract_binned_data(self, binned_data, pos_cols): + """ + """ + df = ioutils.reindex_by_longest( + dfs=binned_data, + idx_names=["trial", "time_bin"], + col_names=["unit"], + return_format="dataframe", + ) + data = df.drop( + labels=["positions", "bin_pos"], + axis=1, + level="unit", + ) + pos = df.filter( + like="pos", + axis=1, + ) + pos.columns.names = ["pos_type", "trial"] + + # convert columns to df + pos_col_df = pos.columns.to_frame(index=False) + start_trial_df = pos_cols.to_frame(index=False) + # merge columns + merged_cols = pd.merge(pos_col_df, start_trial_df, on="trial") + + # create new columns + new_cols = pd.MultiIndex.from_frame( + merged_cols[["pos_type", "start", "trial"]], + names=["pos_type", "start", "trial"], + ) + pos.columns = new_cols + + return data, pos + + + #@cacheable + def get_positional_data( + self, label, event, end_event=None, sigma=None, units=None, + normalised=False, + ): + """ + Get positional firing rate of selected units in vr, and spatial + occupancy of each position. + """ + # NOTE: order of args matters for loading the cache! + # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded + + # get aligned firing rates and positions + trials = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="spike_trial", # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg + ) + + if normalised: + grays = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="spike_trial", # NOTE: ALWAYS the second arg + label=getattr(label, label.name.split("_")[-1]), + event=event.gray_on, + sigma=sigma, + end_event=end_event.gray_off, # NOTE: ALWAYS the last arg + ) + + # NOTE july 24 2025: if get gray mu & sigma per trial we got z score + # of very quiet & stable in gray units >9e+15... thus, we normalise + # average all trials for each unit, rather than per trial + + # NOTE: 500ms of 500Hz sine wave sound at each trial start, 2000ms + # of gray, so only take the second 1000ms in gray to get mean and + # std + + # only select trials exists in aligned trials + baseline = grays["fr"].iloc[ + self.BEHAVIOUR_SAMPLE_RATE: self.BEHAVIOUR_SAMPLE_RATE * 2 + ].loc[:, trials["fr"].columns].T.groupby("unit").mean().T + + mu = baseline.mean() + centered = trials["fr"].sub(mu, axis=1, level="unit") + std = baseline.std() + z_fr = centered.div(std, axis=1, level="unit") + trials["fr"] = z_fr + + del grays, baseline, mu, std, z_fr + gc.collect() + + # get positional spike rate, spike count, and occupancy + positional_data = xut.get_vr_positional_data(trials) + + return positional_data + + + def preprocess_raw(self): + # load raw ap + raw_rec = self.load_raw_ap(copy=True) + + # load brain surface depths + depth_info = file_utils.load_yaml( + path=self.histology / self.files["depth_info"], + ) + surface_depths = depth_info["raw_signal_depths"][self.stream_id] + + # find faulty channels to remove + faulty_channels = file_utils.load_yaml( + path=self.processed / self.files["faulty_channels"], + ) + + # preprocess + self.files["preprocessed"] = xut.preprocess_raw( + raw_rec, + surface_depths, + faulty_channels, + ) + + return None + + + def extract_bands(self, freqs, preprocess=True): + if preprocess: + self.preprocess_raw() + rec = self.files["preprocessed"] + else: + rec = self.load_raw_ap(copy=True) + + if freqs == None: + bands = freq_bands + elif isinstance(freqs, str) and freqs in freq_bands.keys(): + bands = {freqs: freq_bands[freqs]} + elif isinstance(freqs, dict): + bands = freqs + + for name, freqs in bands.items(): + logging.info( + f"\n> Extracting {name} bands from {self.stream_id}." + ) + # do bandpass filtering + extracted = xut.extract_band( + rec, + freq_min=freqs[0], + freq_max=freqs[1], + ) + + logging.info( + f"\n> Common average referencing {name} band." + ) + self.files[f"{name}_extracted"] = xut.CAR(extracted) + + return None + + + def correct_ap_motion(self): + # get ap band + self.extract_bands("ap") + ap_rec = self.files["ap_extracted"] + + # correct ap motion + self.files["ap_motion_corrected"] = xut.correct_ap_motion(ap_rec) + + return None + + + def correct_lfp_motion(self): + raise NotImplementedError("> Not implemented.") + + + def whiten_ap(self): + # get motion corrected ap + mcd = self.files["ap_motion_corrected"] + + # whiten + self.files["ap_whitened"] = xut.whiten(mcd) + + return None + + + def sort_spikes(self, ks_mc, ks4_params, ks_image_path, output, sa_dir): + """ + Sort spikes of stream. + + params + === + ks_mc: bool, whether using kilosort 4 innate motion correction. + + ks4_params: dict, kilosort 4 parameters. + + output: path, directory to save sorting output. + + sa_dir: path, directory to save sorting analyser. + + return + === + None + """ + # use only preprocessed if use ks motion correction + if ks_mc: + self.preprocess_raw() + rec = self.files["preprocessed"] + sa_rec = self.files["ap_motion_corrected"] + else: + # XXX: as of may 2025, whiten ap band and feed to ks reduce units! + #rec = self.files["ap_whitened"] + # use non-whitened recording for ks4 and sorting analyser + sa_rec = rec = self.files["ap_motion_corrected"] + + # sort spikes and save sorting analyser to disk + xut.sort_spikes( + rec=rec, + sa_rec=sa_rec, + output=output, + curated_sa_dir=sa_dir, + ks_image_path=ks_image_path, + ks4_params=ks4_params, + ) + + return None + + + def save_spike_chance(self, spiked, sigma): + # TODO apr 21 2025: + # do we put this func here or in stream.py?? + # save index and columns to reconstruct df for shuffled data + assert 0 + ioutils.save_index_to_frame( + df=spiked, + path=shuffled_idx_path, + ) + ioutils.save_cols_to_frame( + df=spiked, + path=shuffled_col_path, + ) + + # get chance data paths + paths = { + "spiked_memmap_path": self.interim /\ + stream_files["spiked_shuffled_memmap"], + "fr_memmap_path": self.interim / stream_files["fr_shuffled_memmap"], + } + + # save chance data + xut.save_spike_chance( + **paths, + sigma=sigma, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + spiked=spiked, + ) + + return None + + + def get_spatial_psd( + self, label, event, end_event=None, sigma=None, units=None, + crop_from=None, use_binned=False, time_bin=None, pos_bin=None, + ): + """ + Get spatial power spectral density of selected units. + """ + # NOTE: jun 19 2025 + # potentially we could use aligned trials directly for psd estimation, + # with trial position as x, fr as y, and use lomb-scargle method + + # NOTE: order of args matters for loading the cache! + # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded + + # get aligned firing rates and positions + if not use_binned: + trials = self.get_positional_data( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg + ) + crop_from = crop_from + else: + trials = self.get_binned_trials( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + ) + crop_from = crop_from // pos_bin + 1 + + # get positional fr + pos_fr = trials["pos_fr"] + + starts = pos_fr.columns.get_level_values("start").unique() + psds = {} + for s, start in enumerate(starts): + data = pos_fr.xs(start, level="start", axis=1) + # crop if needed + cropped = data.loc[crop_from:, :] + + # get power spectral density + psds[start] = xut.get_spatial_psd(cropped) + + psd_df = pd.concat( + psds, + names=["start","frequency"], + ) + # NOTE: all trials will appear in all starts, but their values will be + # all nan in other starts, so remember to dropna(axis=1)! + + return psd_df + + + @cacheable(cache_format="zarr") + def get_spike_chance(self, units, label, event, sigma, end_event, + # reserved kwargs injected by decorator when cache_format='zarr' + _zarr_out: Path | str | None = None, + ): + # get positions + positions = self._get_vr_positions(label, event, end_event) + + # get spikes + stacked_spikes, _ = self._get_vr_spikes( + units, + label, + event, + sigma, + end_event, + ) + + logging.info( + f"\n> Getting {self.session.name} {self.stream_id} spike chance." + ) + + xut.save_spike_chance_zarr( + zarr_path=_zarr_out, + spiked=stacked_spikes, + sigma=sigma, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + repeats=REPEATS, + positions=positions, + meta=dict( + label=str(label), + event=str(event), + end_event=str(end_event), + units_name=getattr(units, "name", None), + ), + ) + del stacked_spikes, positions + gc.collect() + + return None + + + @cacheable(cache_format="zarr") + def get_binned_chance( + self, units, label, event, sigma, end_event, time_bin, pos_bin, + ): + # get array name and path + name_parts = [self.session.name, self.probe_id, label.name, units.name, + "shuffled", time_bin, f"{pos_bin}cm.npz"] + file_name = "_".join(p for p in name_parts) + arr_path = self.processed / file_name + + # get chance data + chance_data = self.get_spike_chance( + units, + label, + event, + sigma, + end_event, + ) + # bin chance data + binned_chance = xut.bin_spike_chance( + chance_data, + self.BEHAVIOUR_SAMPLE_RATE, + time_bin, + pos_bin, + arr_path, + ) + + return binned_chance + + + def _get_chance_args(self, label, event, sigma, end_event): + probe_id = self.stream_id[:-3] + name = self.session.name + paths = { + "spiked_memmap_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_spiked_shuffled.bin", + "fr_memmap_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_fr_shuffled.bin", + "memmap_shape_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_shuffled_shape.json", + "idx_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_shuffled_index.h5", + # NOTE: if all units, all conditions share the same columns + "col_path": self.interim/\ + self.files["shuffled_columns"], + } + + return positions, paths + + + @cacheable + def get_chance_positional_psd(self, units, label, event, sigma, end_event): + from vision_in_darkness.constants import PRE_DARK_LEN, landmarks + positions, paths = self._get_chance_args( + units, + label, + event, + sigma, + end_event, + ) + + logging.info("> getting chance psd") + psds = xut.save_chance_psd(self.BEHAVIOUR_SAMPLE_RATE, positions, paths) + + return psds + + + @cacheable(cache_format="zarr") + def get_landmark_responsives(self, units, label, sigma): + + from vision_in_darkness.constants import landmarks, mid_walls + from pixels.behaviours.virtual_reality import Events + + wall_events = np.array([ + value for key, value in Events.__dict__.items() + if not (key.startswith("__") or key.startswith("_") + or callable(value) + or isinstance(value, (classmethod, staticmethod))) + and "wall" in key], + dtype=object, + ).reshape(-1, 2) + + # get start and end events, excluding the last landmark + start_events = wall_events[:, 0][:-1] + end_events = wall_events[:, 1][1:] + + # get on & off of landmark and walls and stack them + landmark_names = np.arange(1, len(start_events)+1) + lms = landmarks[1:-2].reshape((-1, 2)) + walls = mid_walls.reshape((-1, 2)) + pre_walls = walls[:-1, :] + post_walls = walls[1:, :] + + # group pre wall, landmarks, and post wall together + chunks = np.stack([pre_walls, lms, post_walls], axis=1) + + ons = chunks[..., 0] + offs = chunks[..., 1] + + all_contrasts = {} + resps = {} + for l, _ in enumerate(lms): + pos_fr = self.get_positional_data( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=start_events[l], + sigma=sigma, + end_event=end_events[l], # NOTE: ALWAYS the last arg + )["pos_fr"] + + lm = landmark_names[l] + all_contrasts[lm], resps[lm] = xut.get_landmark_responsives( + pos_fr=pos_fr, + units=units, + ons=ons[l, :], + offs=offs[l, :], + ) + + responsives = pd.concat(resps, axis=1, names="landmark") + + contrasts = pd.concat( + all_contrasts, + axis=0, + names=["landmark", "unit"], + ) + contrasts.columns.name = "metrics" + contrasts.start = contrasts.start.astype(int) + # so that row index is unique + contrasts = contrasts.set_index(["start", "contrast"], append=True) + + return {"contrasts": contrasts, "responsives": responsives} diff --git a/pixels/units.py b/pixels/units.py new file mode 100644 index 0000000..bcbb4d3 --- /dev/null +++ b/pixels/units.py @@ -0,0 +1,33 @@ +from typing import Iterable + +class SelectedUnits(dict[str, list[int]]): + name: str + """ + A mapping from stream_id to lists of unit IDs. + Behaves like a dict in every way, except that when represented as a string, + it can return a `name` if set. This allows named instances to be cached to file. + """ + + def __init__(self, *args, name: str | None = None, **kwargs): + super().__init__(*args, **kwargs) + if name is not None: + self.name = name + + def __repr__(self) -> str: + if hasattr(self, "name"): + return self.name + return dict.__repr__(self) + + # Convenience helpers + def add(self, stream_id: str, unit_id: int) -> None: + self.setdefault(stream_id, []).append(unit_id) + + def extend(self, stream_id: str, unit_ids: Iterable[int]) -> None: + self.setdefault(stream_id, []).extend(unit_ids) + + def flat(self) -> list[int]: + # If you sometimes need a flat list of all unit IDs + out: list[int] = [] + for ids in self.values(): + out.extend(ids) + return out