From 5161a184b190c38006e72f3b8ac5870f71fd2ae7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jun 2022 19:07:18 +0100 Subject: [PATCH 001/658] add Fourier method resample func to make lfp downsampling work --- pixels/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/signal.py b/pixels/signal.py index 6451c92..a895de6 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -79,7 +79,7 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): if chunks > 1: print(f" 0%", end="\r") current = 0 - for _ in range(chunks): + for i in range(chunks): chunk_data = array[:, current:min(current + chunk_size, cols)] if poly: # matt's old poly func From c7d0b8c8847e066eff36982a5c6e23fd99432afa Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jun 2022 19:29:54 +0100 Subject: [PATCH 002/658] perform median subtraction and downsampling on lfp data and save output --- pixels/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/signal.py b/pixels/signal.py index a895de6..6451c92 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -79,7 +79,7 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): if chunks > 1: print(f" 0%", end="\r") current = 0 - for i in range(chunks): + for _ in range(chunks): chunk_data = array[:, current:min(current + chunk_size, cols)] if poly: # matt's old poly func From e90887cd6cded9341eacd548f47cf8e34eb206be Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Jun 2022 18:38:39 +0100 Subject: [PATCH 003/658] make sure process_lfp does not use resample_poly --- pixels/behaviours/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index dc98056..d0a0a21 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -441,10 +441,11 @@ def process_lfp(self): sync_chan = downsampled[:, -1] downsampled = downsampled[:, :-1] + """ if self._lag[rec_num] is None: self.sync_data(rec_num, sync_channel=data[:, -1]) lag_start, lag_end = self._lag[rec_num] - + """ sd = self.processed / recording['lfp_sd'] if sd.exists(): continue @@ -460,10 +461,12 @@ def process_lfp(self): with open(sd, 'w') as fd: json.dump(results, fd) + """ if lag_end < 0: data = data[:lag_end] if lag_start < 0: data = data[- lag_start:] + """ print(f"> Saving median subtracted & downsampled LFP to {output}") # save in .npy format np.save( From 6d323e6f145c45ef0de6cc13baee3594640aa936 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 24 Jun 2022 23:18:13 +0100 Subject: [PATCH 004/658] sort CatGT-ed ap data if available --- pixels/behaviours/base.py | 34 +++++++++++++++++++++++++--------- pixels/ioutils.py | 2 ++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d0a0a21..c6a6d92 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -8,10 +8,12 @@ import functools import json import os +import glob import pickle import shutil import tarfile import tempfile +import re from abc import ABC, abstractmethod from collections import defaultdict from pathlib import Path @@ -139,6 +141,10 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): else: self.interim = Path(interim_dir) / self.name + self.catGT_dir = glob.glob( + str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' + ) + self.interim.mkdir(parents=True, exist_ok=True) self.processed.mkdir(parents=True, exist_ok=True) @@ -477,32 +483,42 @@ def process_lfp(self): #downsampled 
= pd.DataFrame(downsampled) #ioutils.write_hdf5(output, downsampled) + def sort_spikes(self): """ Run kilosort spike sorting on raw spike data. """ streams = {} - for rec_num, files in enumerate(self.files): - data_file = self.find_file(files['spike_data']) - assert data_file, f"Spike data not found for {files['spike_data']}." - - stream_id = data_file.as_posix()[-12:-4] - if stream_id not in streams: + for _, files in enumerate(self.files): + if len(self.catGT_dir) == 0: + print(f"> Spike data not found for {files['catGT_ap_data']},\ + \nuse the orignial spike data.") + data_file = self.find_file(files['spike_data']) metadata = self.find_file(files['spike_meta']) - streams[stream_id] = metadata + else: + self.catGT_dir = Path(self.catGT_dir[0]) + data_file = self.catGT_dir / files['catGT_ap_data'] + metadata = self.catGT_dir / files['catGT_ap_meta'] + + stream_id = data_file.as_posix()[-12:-4] + if stream_id not in streams: + streams[stream_id] = metadata for stream_num, stream in enumerate(streams.items()): stream_id, metadata = stream try: - recording = se.SpikeGLXRecordingExtractor(self.interim, stream_id=stream_id) + recording = se.SpikeGLXRecordingExtractor(self.catGT_dir, stream_id=stream_id) except ValueError as e: raise PixelsError( f"Did the raw data get fully copied to interim? Full error: {e}" ) print("> Running kilosort") - output = self.processed / f'sorted_stream_{stream_num}' + if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: + output = self.processed / f'sorted_stream_cat_{stream_num}' + else: + output = self.processed / f'sorted_stream_{stream_num}' concat_rec = si.concatenate_recordings([recording]) probe = pi.read_spikeglx(metadata.as_posix()) concat_rec = concat_rec.set_probe(probe) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 73b64e1..5b43884 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -117,6 +117,8 @@ def get_data_files(data_dir, session_name): recording['depth_info'] = recording['lfp_data'].with_name( f'depth_info_{num}.json' ) + recording['catGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") + recording['catGT_ap_meta'] = str(recording['spike_meta']).replace("t0", "tcat") files.append(recording) return files From 759c63f2b1c56e086825bef5d5840d0e83bc2325 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Nov 2022 11:51:12 +0000 Subject: [PATCH 005/658] temporary add catgt thing; tempt to fix remove duplicate units --- pixels/behaviours/base.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c6a6d92..47b3a92 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -497,6 +497,7 @@ def sort_spikes(self): data_file = self.find_file(files['spike_data']) metadata = self.find_file(files['spike_meta']) else: + print("> Sorting catgt-ed spikes") self.catGT_dir = Path(self.catGT_dir[0]) data_file = self.catGT_dir / files['catGT_ap_data'] metadata = self.catGT_dir / files['catGT_ap_meta'] @@ -1118,14 +1119,19 @@ def get_lfp_data(self): """ return self._get_processed_data("_lfp_data", "lfp_processed") - def _get_spike_times(self): + def _get_spike_times(self, catgt=False): """ Returns the sorted spike times. 
""" saved = self._spike_times_data if saved[0] is None: - times = self.processed / f'sorted_stream_0' / 'spike_times.npy' - clust = self.processed / f'sorted_stream_0' / 'spike_clusters.npy' + # TODO: temporarily add catgt arg here, + if catgt: + stream = 'sorted_stream_cat_0' + else: + stream = 'sorted_stream_0' + times = self.processed / stream / f'spike_times.npy' + clust = self.processed / stream / f'spike_clusters.npy' try: times = np.load(times) @@ -1139,8 +1145,18 @@ def _get_spike_times(self): by_clust = {} for c in np.unique(clust): - by_clust[c] = pd.Series(times[clust == c]).drop_duplicates() + c_times = times[clust == c] + uniques, counts = np.unique( + c_times, + return_counts=True, + ) + repeats = c_times[np.where(counts>1)] + if len(repeats>1): + print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") + + by_clust[c] = pd.Series(uniques) saved[0] = pd.concat(by_clust, axis=1, names=['unit']) + assert 0 return saved[0] def _get_aligned_spike_times( From 437cce97437eb6fcd1fa697f9d26f371362974f4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jun 2022 19:07:18 +0100 Subject: [PATCH 006/658] add Fourier method resample func to make lfp downsampling work --- pixels/signal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/signal.py b/pixels/signal.py index 6451c92..84ec940 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -16,7 +16,7 @@ from pixels import ioutils, PixelsError -def resample(array, from_hz, to_hz, poly=True, padtype=None): +def resample(array, from_hz, to_hz, poly=False, padtype=None): """ Resample an array from one sampling rate to another. @@ -79,7 +79,7 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): if chunks > 1: print(f" 0%", end="\r") current = 0 - for _ in range(chunks): + for i in range(chunks): chunk_data = array[:, current:min(current + chunk_size, cols)] if poly: # matt's old poly func From eff966184f17d31856e502da38667e2427c21864 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jun 2022 19:29:54 +0100 Subject: [PATCH 007/658] perform median subtraction and downsampling on lfp data and save output --- pixels/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/signal.py b/pixels/signal.py index 84ec940..93ec9ae 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -79,7 +79,7 @@ def resample(array, from_hz, to_hz, poly=False, padtype=None): if chunks > 1: print(f" 0%", end="\r") current = 0 - for i in range(chunks): + for _ in range(chunks): chunk_data = array[:, current:min(current + chunk_size, cols)] if poly: # matt's old poly func From 86edfddb9a9d7823c7431f7f8fb34e62eac3c99b Mon Sep 17 00:00:00 2001 From: amz <77285087+arthurz323@users.noreply.github.com> Date: Fri, 17 Jun 2022 18:34:52 +0100 Subject: [PATCH 008/658] use resample_poly as default resample method --- pixels/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/signal.py b/pixels/signal.py index 93ec9ae..6451c92 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -16,7 +16,7 @@ from pixels import ioutils, PixelsError -def resample(array, from_hz, to_hz, poly=False, padtype=None): +def resample(array, from_hz, to_hz, poly=True, padtype=None): """ Resample an array from one sampling rate to another. 
From 70a1400e826a8f1a429bb007e074e4f1790b7e30 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 4 Jul 2022 16:00:59 +0100 Subject: [PATCH 009/658] do sync lag check --- pixels/behaviours/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 47b3a92..7973611 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -447,11 +447,10 @@ def process_lfp(self): sync_chan = downsampled[:, -1] downsampled = downsampled[:, :-1] - """ if self._lag[rec_num] is None: self.sync_data(rec_num, sync_channel=data[:, -1]) lag_start, lag_end = self._lag[rec_num] - """ + sd = self.processed / recording['lfp_sd'] if sd.exists(): continue @@ -467,12 +466,10 @@ def process_lfp(self): with open(sd, 'w') as fd: json.dump(results, fd) - """ if lag_end < 0: data = data[:lag_end] if lag_start < 0: data = data[- lag_start:] - """ print(f"> Saving median subtracted & downsampled LFP to {output}") # save in .npy format np.save( From 21ba98ea0353e7b00644be50738f0533f74e1236 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 11 Nov 2022 18:00:28 +0000 Subject: [PATCH 010/658] remove empty and redundant units after running spike sorting; start adding waveform extractor from spikeinterface --- pixels/behaviours/base.py | 78 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7973611..234a5dc 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -28,6 +28,7 @@ import spikeinterface as si import spikeinterface.extractors as se import spikeinterface.sorters as ss +import spikeinterface.curation as sc from scipy import interpolate from tables import HDF5ExtError @@ -520,7 +521,21 @@ def sort_spikes(self): concat_rec = si.concatenate_recordings([recording]) probe = pi.read_spikeglx(metadata.as_posix()) concat_rec = concat_rec.set_probe(probe) - ss.run_kilosort3(recording=concat_rec, output_folder=output) + ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) + + # remove empty units + ks3_no_empt = ks3_output.remove_empty_units() + print(f'KS3 found {len(ks3_no_empt.get_unit_ids())} non-empty units.') + + # remove redundant units by keeping minimum shift, highest_amplitude, or + # max_spikes + sc.remove_redundant_units( + ks3_no_empt, + WaveformExtractor, # spike trains realigned using the peak shift in template + duplicate_threshold=0.9, # default is 0.8 + remove_strategy='minimum_shift', # keep unit with best peak alignment + ) + assert 0 def extract_videos(self, force=False): @@ -1713,6 +1728,67 @@ def get_spike_waveforms(self, units=None): df.index = df.index * rate return df + def get_spike_waveforms_si(self, units=None): + import spikeinterface as si + import spikeinterface.extractors as se + + # set chunks + job_kwargs = dict( + n_jobs=10, + chunk_duration="1s", + progress_bar=True, + ) + + # read raw spikeglx recordings from interim + recording = se.SpikeGlxrRecordingExtractor(folder_path=self.interim) + assert 0 + + # extract waveforms + waveforms = si.extract_waveforms( + self.interim, + folder=self.processed / 'sorted_stream_0' /, + load_if_exists=True, # load extracted if available + max_spikes_per_unit=None, + overwrite=False, + **job_kwargs, + ) + + if units: + # defer to getting waveforms for all units + waveforms = self.get_spike_waveforms()[units] + assert list(waveforms.columns.get_level_values("unit").unique()) == list(units) + return waveforms + + units = 
self.select_units() + + paramspy = self.processed / 'sorted_stream_0' / 'params.py' + if not paramspy.exists(): + raise PixelsError(f"{self.name}: params.py not found") + model = load_model(paramspy) + rec_forms = {} + + for u, unit in enumerate(units): + print(100 * u / len(units), "% complete") + # get the waveforms from only the best channel + spike_ids = model.get_cluster_spikes(unit) + best_chan = model.get_cluster_channels(unit)[0] + u_waveforms = model.get_waveforms(spike_ids, [best_chan]) + if u_waveforms is None: + raise PixelsError(f"{self.name}: unit {unit} - waveforms not read") + rec_forms[unit] = pd.DataFrame(np.squeeze(u_waveforms).T) + + assert rec_forms + + df = pd.concat( + rec_forms, + axis=1, + names=['unit', 'spike'] + ) + # convert indexes to ms + rate = 1000 / int(self.spike_meta[0]['imSampRate']) + df.index = df.index * rate + return df + @_cacheable def get_aligned_spike_rate_CI( self, label, event, From 72c20722b7591f56abbf31784039d6e62852b2a9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 11 Nov 2022 19:03:42 +0000 Subject: [PATCH 011/658] save kilosort 3 sorting object; completing waveform extraction --- pixels/behaviours/base.py | 52 +++++++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 234a5dc..0c4adca 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -521,7 +521,12 @@ def sort_spikes(self): concat_rec = si.concatenate_recordings([recording]) probe = pi.read_spikeglx(metadata.as_posix()) concat_rec = concat_rec.set_probe(probe) - ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) + #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) + ks3_output = ss.run_sorter( + sorter_name='kilosort3', + recording=concat_rec, + output_folder=output, + ) # remove empty units ks3_no_empt = ks3_output.remove_empty_units() @@ -529,12 +534,14 @@ def sort_spikes(self): # remove redundant units by keeping minimum shift, highest_amplitude, or # max_spikes - sc.remove_redundant_units( + ks3_output = sc.remove_redundant_units( ks3_no_empt, WaveformExtractor, # spike trains realigned using the peak shift in template duplicate_threshold=0.9, # default is 0.8 remove_strategy='minimum_shift', # keep unit with best peak alignment ) + # save spikeinterface sorting object for easier loading + ks3_output.save(folder=output / 'saved_si_sorting_obj') assert 0 @@ -1729,9 +1736,7 @@ def get_spike_waveforms(self, units=None): return df def get_spike_waveforms_si(self, units=None): - import spikeinterface as si - import spikeinterface.extractors as se - + streams = {} # set chunks job_kwargs = dict( n_jobs=10, @@ -1739,20 +1744,46 @@ def get_spike_waveforms_si(self, units=None): progress_bar=True, ) - # read raw spikeglx recordings from interim - recording = se.SpikeGlxrRecordingExtractor(folder_path=self.interim) - assert 0 + for rec_num, files in enumerate(self.files): + data_file = self.find_file(files['spike_data']) + assert data_file, f"Spike data not found for {files['spike_data']}." + + stream_id = data_file.as_posix()[-12:-4] + if stream_id not in streams: + metadata = self.find_file(files['spike_meta']) + streams[stream_id] = metadata + + for stream_num, stream in enumerate(streams.items()): + stream_id, metadata = stream + try: + recording = se.SpikeGLXRecordingExtractor(self.interim, stream_id=stream_id) + except ValueError as e: + raise PixelsError( + f"Did the raw data get fully copied to interim? 
Full error: {e}" + ) + try: + sorting = se.NpzSortingExtractor( + self.processed / 'sorted_stream_{stream_num}/saved_si_sorting_obj', + ) + except ValueError as e: + raise PixelsError( + f"Have you run spike sorting? Full error: {e}" + ) # extract waveforms waveforms = si.extract_waveforms( - self.interim, - folder=self.processed / 'sorted_stream_0' /, + recording=recording, + sorting=sorting, + folder=self.interim/ 'cache', load_if_exists=True, # load extracted if available max_spikes_per_unit=None, overwrite=False, **job_kwargs, ) + assert 0 + """ + # Matt's stuff if units: # defer to getting waveforms for all units waveforms = self.get_spike_waveforms()[units] @@ -1788,6 +1819,7 @@ def get_spike_waveforms_si(self, units=None): rate = 1000 / int(self.spike_meta[0]['imSampRate']) df.index = df.index * rate return df + """ @_cacheable def get_aligned_spike_rate_CI( From 104a2c9fd48826dc72e400e58a1fba8874618d54 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 11 Nov 2022 19:04:12 +0000 Subject: [PATCH 012/658] export_to_phy place holder --- pixels/behaviours/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 0c4adca..c4d6b28 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -542,6 +542,7 @@ def sort_spikes(self): ) # save spikeinterface sorting object for easier loading ks3_output.save(folder=output / 'saved_si_sorting_obj') + #TODO: export_to_phy assert 0 From f4ddefce156150fe6f5eb213f7786889da28ac37 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 14 Nov 2022 18:34:38 +0000 Subject: [PATCH 013/658] use spikeinterface to extract waveforms, and use export_to_phy to calculate params for phy --- pixels/behaviours/base.py | 115 ++++++++++---------------------------- 1 file changed, 28 insertions(+), 87 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c4d6b28..4c0a516 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -29,6 +29,7 @@ import spikeinterface.extractors as se import spikeinterface.sorters as ss import spikeinterface.curation as sc +import spikeinterface.exporters as sexp from scipy import interpolate from tables import HDF5ExtError @@ -487,6 +488,12 @@ def sort_spikes(self): Run kilosort spike sorting on raw spike data. 
""" streams = {} + # set chunks + job_kwargs = dict( + n_jobs=10, # -1: num of job equals num of cores + chunk_duration="1s", + progress_bar=True, + ) for _, files in enumerate(self.files): if len(self.catGT_dir) == 0: @@ -542,7 +549,27 @@ def sort_spikes(self): ) # save spikeinterface sorting object for easier loading ks3_output.save(folder=output / 'saved_si_sorting_obj') - #TODO: export_to_phy + + # extract waveforms + waveforms = si.extract_waveforms( + recording=concat_rec, + sorting=ks3_output, + folder=self.interim/ 'cache', + load_if_exists=True, # load extracted if available + max_spikes_per_unit=None, + overwrite=False, + **job_kwargs, + ) + + # export_to_phy + sexp.export_to_phy( + waveform_extractor=waveforms, + output_folder=output / "phy_ks3", + compute_pc_features=True, + compute_amplitudes=True, + copy_binary=True, + **job_kwargs, + ) assert 0 @@ -1736,92 +1763,6 @@ def get_spike_waveforms(self, units=None): df.index = df.index * rate return df - def get_spike_waveforms_si(self, units=None): - streams = {} - # set chunks - job_kwargs = dict( - n_jobs=10, - chunk_duration="1s", - progress_bar=True, - ) - - for rec_num, files in enumerate(self.files): - data_file = self.find_file(files['spike_data']) - assert data_file, f"Spike data not found for {files['spike_data']}." - - stream_id = data_file.as_posix()[-12:-4] - if stream_id not in streams: - metadata = self.find_file(files['spike_meta']) - streams[stream_id] = metadata - - for stream_num, stream in enumerate(streams.items()): - stream_id, metadata = stream - try: - recording = se.SpikeGLXRecordingExtractor(self.interim, stream_id=stream_id) - except ValueError as e: - raise PixelsError( - f"Did the raw data get fully copied to interim? Full error: {e}" - ) - try: - sorting = se.NpzSortingExtractor( - self.processed / 'sorted_stream_{stream_num}/saved_si_sorting_obj', - ) - except ValueError as e: - raise PixelsError( - f"Have you run spike sorting? 
Full error: {e}" - ) - - # extract waveforms - waveforms = si.extract_waveforms( - recording=recording, - sorting=sorting, - folder=self.interim/ 'cache', - load_if_exists=True, # load extracted if available - max_spikes_per_unit=None, - overwrite=False, - **job_kwargs, - ) - assert 0 - - """ - # Matt's stuff - if units: - # defer to getting waveforms for all units - waveforms = self.get_spike_waveforms()[units] - assert list(waveforms.columns.get_level_values("unit").unique()) == list(units) - return waveforms - - units = self.select_units() - - paramspy = self.processed / 'sorted_stream_0' / 'params.py' - if not paramspy.exists(): - raise PixelsError(f"{self.name}: params.py not found") - model = load_model(paramspy) - rec_forms = {} - - for u, unit in enumerate(units): - print(100 * u / len(units), "% complete") - # get the waveforms from only the best channel - spike_ids = model.get_cluster_spikes(unit) - best_chan = model.get_cluster_channels(unit)[0] - u_waveforms = model.get_waveforms(spike_ids, [best_chan]) - if u_waveforms is None: - raise PixelsError(f"{self.name}: unit {unit} - waveforms not read") - rec_forms[unit] = pd.DataFrame(np.squeeze(u_waveforms).T) - - assert rec_forms - - df = pd.concat( - rec_forms, - axis=1, - names=['unit', 'spike'] - ) - # convert indexes to ms - rate = 1000 / int(self.spike_meta[0]['imSampRate']) - df.index = df.index * rate - return df - """ - @_cacheable def get_aligned_spike_rate_CI( self, label, event, From a0e002c1cb20ecaf08ff079ecc23769ed38a3ddb Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 15 Nov 2022 19:28:31 +0000 Subject: [PATCH 014/658] load sorting and waveforms if already done; compute pc features for ks3 output; add spikeinterface method to extract waveforms. --- pixels/behaviours/base.py | 275 +++++++++++++++++++++++++++++--------- 1 file changed, 209 insertions(+), 66 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4c0a516..7ea11fe 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -495,6 +495,9 @@ def sort_spikes(self): progress_bar=True, ) + #TODO: consider to put catgt here + + # find spike data to sort for _, files in enumerate(self.files): if len(self.catGT_dir) == 0: print(f"> Spike data not found for {files['catGT_ap_data']},\ @@ -515,31 +518,93 @@ def sort_spikes(self): stream_id, metadata = stream try: recording = se.SpikeGLXRecordingExtractor(self.catGT_dir, stream_id=stream_id) + # this recording is filtered + recording.annotate(is_filtered=True) except ValueError as e: raise PixelsError( f"Did the raw data get fully copied to interim? 
Full error: {e}" ) - print("> Running kilosort") + # find spike sorting output folder if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: output = self.processed / f'sorted_stream_cat_{stream_num}' else: output = self.processed / f'sorted_stream_{stream_num}' - concat_rec = si.concatenate_recordings([recording]) - probe = pi.read_spikeglx(metadata.as_posix()) - concat_rec = concat_rec.set_probe(probe) - #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) - ks3_output = ss.run_sorter( - sorter_name='kilosort3', - recording=concat_rec, - output_folder=output, - ) - # remove empty units - ks3_no_empt = ks3_output.remove_empty_units() - print(f'KS3 found {len(ks3_no_empt.get_unit_ids())} non-empty units.') + #TODO + try: + ks3_output = si.load_extractor(output) + print("> This session is already sorted, now it is loaded.") + assert 0 + except: + print("> Running kilosort") + # concatenate recording segments + concat_rec = si.concatenate_recordings([recording]) + probe = pi.read_spikeglx(metadata.as_posix()) + concat_rec = concat_rec.set_probe(probe) + # annotate spike data is filtered + concat_rec.annotate(is_filtered=True) + + # for testing: get first 5 mins of the recording + fs = concat_rec.get_sampling_frequency() + test = concat_rec.frame_slice( + start_frame=0*fs, + end_frame=300*fs, + ) + test.annotate(is_filtered=True) + # check all annotations + test.get_annotation('is_filtered') + print(test) + + #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) + ks3_output = ss.run_sorter( + sorter_name='kilosort3', + #recording=concat_rec, + recording=test, # for testing + output_folder=output, + #remove_existing_folder=False, + **job_kwargs, + ) + + # remove empty units + ks3_output = ks3_output.remove_empty_units() + print(f'KS3 found {len(ks3_no_empt.get_unit_ids())} non-empty units.') + + """ + #TODO: remove duplicated spikes from spike train, only in >0.96.1 si + ks3_output = sc.remove_duplicated_spikes( + sorting=ks3_no_empt, + censored_period_ms=0.3, #ms + method='keep_first', # keep first spike, remove the second + ) + """ + # save spikeinterface sorting object for easier loading + ks3_output.save(folder=output / 'saved_si_sorting_obj') + + try: + waveforms = si.WaveformExtractor.load_from_folder( + folder=self.interim / 'cache', + sorting=ks3_output, + ) + print("> Waveforms extracted, now it is loaded.") + except: + print("> Waveforms not extracted, extracting now.") + # extract waveforms + waveforms = si.extract_waveforms( + #recording=concat_rec, + recording=test, # for testing + sorting=ks3_output, + folder=self.interim / 'cache', + load_if_exists=True, # load extracted if available + #load_if_exists=False, # re-calculate everytime + max_spikes_per_unit=None, # extract all waveforms + overwrite=False, + **job_kwargs, + ) + assert 0 - # remove redundant units by keeping minimum shift, highest_amplitude, or + """ + # TODO: remove redundant units by keeping minimum shift, highest_amplitude, or # max_spikes ks3_output = sc.remove_redundant_units( ks3_no_empt, @@ -547,27 +612,16 @@ def sort_spikes(self): duplicate_threshold=0.9, # default is 0.8 remove_strategy='minimum_shift', # keep unit with best peak alignment ) - # save spikeinterface sorting object for easier loading - ks3_output.save(folder=output / 'saved_si_sorting_obj') - - # extract waveforms - waveforms = si.extract_waveforms( - recording=concat_rec, - sorting=ks3_output, - folder=self.interim/ 'cache', - load_if_exists=True, # load extracted if available - 
max_spikes_per_unit=None, - overwrite=False, - **job_kwargs, - ) + """ - # export_to_phy + # TODO sexp.export_to_phy( waveform_extractor=waveforms, output_folder=output / "phy_ks3", compute_pc_features=True, compute_amplitudes=True, copy_binary=True, + remove_if_exists=True, # load if already exported to phy **job_kwargs, ) assert 0 @@ -1723,45 +1777,134 @@ def get_spike_widths(self, units=None): return df @_cacheable - def get_spike_waveforms(self, units=None): - from phylib.io.model import load_model - from phylib.utils.color import selected_cluster_color + def get_spike_waveforms(self, units=None, method='phy'): + if method == 'phy': + from phylib.io.model import load_model + from phylib.utils.color import selected_cluster_color - if units: - # defer to getting waveforms for all units - waveforms = self.get_spike_waveforms()[units] - assert list(waveforms.columns.get_level_values("unit").unique()) == list(units) - return waveforms - - units = self.select_units() - - paramspy = self.processed / 'sorted_stream_0' / 'params.py' - if not paramspy.exists(): - raise PixelsError(f"{self.name}: params.py not found") - model = load_model(paramspy) - rec_forms = {} - - for u, unit in enumerate(units): - print(100 * u / len(units), "% complete") - # get the waveforms from only the best channel - spike_ids = model.get_cluster_spikes(unit) - best_chan = model.get_cluster_channels(unit)[0] - u_waveforms = model.get_waveforms(spike_ids, [best_chan]) - if u_waveforms is None: - raise PixelsError(f"{self.name}: unit {unit} - waveforms not read") - rec_forms[unit] = pd.DataFrame(np.squeeze(u_waveforms).T) - - assert rec_forms - - df = pd.concat( - rec_forms, - axis=1, - names=['unit', 'spike'] - ) - # convert indexes to ms - rate = 1000 / int(self.spike_meta[0]['imSampRate']) - df.index = df.index * rate - return df + if units: + # defer to getting waveforms for all units + waveforms = self.get_spike_waveforms()[units] + assert list(waveforms.columns.get_level_values("unit").unique()) == list(units) + return waveforms + + units = self.select_units() + + paramspy = self.processed / 'sorted_stream_0' / 'params.py' + if not paramspy.exists(): + raise PixelsError(f"{self.name}: params.py not found") + model = load_model(paramspy) + rec_forms = {} + + for u, unit in enumerate(units): + print(100 * u / len(units), "% complete") + # get the waveforms from only the best channel + spike_ids = model.get_cluster_spikes(unit) + best_chan = model.get_cluster_channels(unit)[0] + u_waveforms = model.get_waveforms(spike_ids, [best_chan]) + if u_waveforms is None: + raise PixelsError(f"{self.name}: unit {unit} - waveforms not read") + rec_forms[unit] = pd.DataFrame(np.squeeze(u_waveforms).T) + + assert rec_forms + + df = pd.concat( + rec_forms, + axis=1, + names=['unit', 'spike'] + ) + # convert indexes to ms + rate = 1000 / int(self.spike_meta[0]['imSampRate']) + df.index = df.index * rate + return df + + #TODO: implement spikeinterface waveform extraction + elif method == 'spikeinterface': + streams = {} + # set chunks + job_kwargs = dict( + n_jobs=10, # -1: num of job equals num of cores + chunk_duration="1s", + progress_bar=True, + ) + + # load recording and sorting object + for _, files in enumerate(self.files): + if len(self.catGT_dir) == 0: + print(f"> Spike data not found for {files['catGT_ap_data']},\ + \nuse the orignial recording data.") + data_file = self.find_file(files['spike_data']) + metadata = self.find_file(files['spike_meta']) + else: + print("> Use catgt-ed recording") + self.catGT_dir = 
Path(self.catGT_dir[0]) + data_file = self.catGT_dir / files['catGT_ap_data'] + metadata = self.catGT_dir / files['catGT_ap_meta'] + + stream_id = data_file.as_posix()[-12:-4] + if stream_id not in streams: + streams[stream_id] = metadata + + for stream_num, stream in enumerate(streams.items()): + stream_id, metadata = stream + try: + recording = se.SpikeGLXRecordingExtractor(self.catGT_dir, stream_id=stream_id) + # this recording is filtered + recording.annotate(is_filtered=True) + except ValueError as e: + raise PixelsError( + f"Did the raw data get fully copied to interim? Full error: {e}" + ) + try: + # load sorting object + sorting = si.load_extractor(self.processed) + except ValueError as e: + raise PixelsError( + f"Have you run spike sorting yet? Full error: {e}" + ) + + try: + waveforms = si.WaveformExtractor.load_from_folder( + folder=self.interim / 'cache', + sorting=sorting, + ) + except: + print("> Waveforms not extracted, extracting now.") + + #TODO + if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: + output = self.processed / f'sorted_stream_cat_{stream_num}' + else: + output = self.processed / f'sorted_stream_{stream_num}' + + # for testing: get first 5 mins of the recording + fs = concat_rec.get_sampling_frequency() + test = concat_rec.frame_slice( + start_frame=0*fs, + end_frame=300*fs, + ) + test.annotate(is_filtered=True) + # check all annotations + test.get_annotation('is_filtered') + print(test) + + # extract waveforms + waveforms = si.extract_waveforms( + #recording=concat_rec, + recording=test, # for testing + sorting=ks3_no_empt, # sorting=ks3_output after remove dups + folder=self.interim / 'cache', + #load_if_exists=True, # load extracted if available + load_if_exists=False, # re-calculate everytime + max_spikes_per_unit=None, # extract all waveforms + overwrite=False, + **job_kwargs, + ) + assert 0 + + else: + raise PixelsError(f"{self.name}: waveform extraction method {method} is\ + not implemented!") @_cacheable def get_aligned_spike_rate_CI( From 242c9eaca61342665275b13c563c8bc71e083dea Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 23 Nov 2022 12:26:15 +0000 Subject: [PATCH 015/658] make featureview in phy work --- pixels/behaviours/base.py | 81 ++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 35 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7ea11fe..35d37bf 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -30,6 +30,7 @@ import spikeinterface.sorters as ss import spikeinterface.curation as sc import spikeinterface.exporters as sexp +import spikeinterface.postprocessing as spost from scipy import interpolate from tables import HDF5ExtError @@ -488,7 +489,7 @@ def sort_spikes(self): Run kilosort spike sorting on raw spike data. 
""" streams = {} - # set chunks + # set chunks for spikeinterface operations job_kwargs = dict( n_jobs=10, # -1: num of job equals num of cores chunk_duration="1s", @@ -501,11 +502,11 @@ def sort_spikes(self): for _, files in enumerate(self.files): if len(self.catGT_dir) == 0: print(f"> Spike data not found for {files['catGT_ap_data']},\ - \nuse the orignial spike data.") + \nuse the orignial spike data.\n") data_file = self.find_file(files['spike_data']) metadata = self.find_file(files['spike_meta']) else: - print("> Sorting catgt-ed spikes") + print("> Sorting catgt-ed spikes\n") self.catGT_dir = Path(self.catGT_dir[0]) data_file = self.catGT_dir / files['catGT_ap_data'] metadata = self.catGT_dir / files['catGT_ap_meta'] @@ -516,35 +517,44 @@ def sort_spikes(self): for stream_num, stream in enumerate(streams.items()): stream_id, metadata = stream + # find spike sorting output folder + if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: + output = self.processed / f'sorted_stream_cat_{stream_num}' + else: + output = self.processed / f'sorted_stream_{stream_num}' + + # check if already sorted and exported + for_phy = output / "phy_ks3" + if not os.path.exists(for_phy) or len(os.listdir(for_phy)) == 0: + print("> Not sorted yet, start spike sorting...\n") + else: + print("> Already sorted and exported, next session...\n") + continue + try: recording = se.SpikeGLXRecordingExtractor(self.catGT_dir, stream_id=stream_id) # this recording is filtered recording.annotate(is_filtered=True) except ValueError as e: raise PixelsError( - f"Did the raw data get fully copied to interim? Full error: {e}" + f"Did the raw data get fully copied to interim? Full error: {e}\n" ) - # find spike sorting output folder - if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: - output = self.processed / f'sorted_stream_cat_{stream_num}' - else: - output = self.processed / f'sorted_stream_{stream_num}' - - #TODO try: - ks3_output = si.load_extractor(output) - print("> This session is already sorted, now it is loaded.") - assert 0 + ks3_output = si.load_extractor(output / + f'saved_si_sorting_obj_{stream_num}') + print("> This session is already sorted, now it is loaded.\n") except: - print("> Running kilosort") + print("> Running kilosort\n") # concatenate recording segments concat_rec = si.concatenate_recordings([recording]) probe = pi.read_spikeglx(metadata.as_posix()) concat_rec = concat_rec.set_probe(probe) # annotate spike data is filtered concat_rec.annotate(is_filtered=True) + print(f"> Now is sorting: \n{concat_rec}\n") + """ # for testing: get first 5 mins of the recording fs = concat_rec.get_sampling_frequency() test = concat_rec.frame_slice( @@ -555,12 +565,12 @@ def sort_spikes(self): # check all annotations test.get_annotation('is_filtered') print(test) + """ #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) ks3_output = ss.run_sorter( sorter_name='kilosort3', - #recording=concat_rec, - recording=test, # for testing + recording=concat_rec, #recording=test, # for testing output_folder=output, #remove_existing_folder=False, **job_kwargs, @@ -568,7 +578,7 @@ def sort_spikes(self): # remove empty units ks3_output = ks3_output.remove_empty_units() - print(f'KS3 found {len(ks3_no_empt.get_unit_ids())} non-empty units.') + print(f"KS3 found {len(ks3_output.get_unit_ids())} non-empty units.\n") """ #TODO: remove duplicated spikes from spike train, only in >0.96.1 si @@ -586,45 +596,46 @@ def sort_spikes(self): folder=self.interim / 'cache', sorting=ks3_output, ) - print("> 
Waveforms extracted, now it is loaded.") + print("> Waveforms extracted, now it is loaded.\n") except: - print("> Waveforms not extracted, extracting now.") + print("> Waveforms not extracted, extracting now.\n") # extract waveforms waveforms = si.extract_waveforms( - #recording=concat_rec, - recording=test, # for testing + recording=concat_rec, #recording=test, # for testing sorting=ks3_output, folder=self.interim / 'cache', - load_if_exists=True, # load extracted if available - #load_if_exists=False, # re-calculate everytime - max_spikes_per_unit=None, # extract all waveforms - overwrite=False, + #load_if_exists=True, # load extracted if available + load_if_exists=False, # re-calculate everytime + max_spikes_per_unit=1000, # None will extract all waveforms + #overwrite=False, + overwrite=True, **job_kwargs, ) - assert 0 """ # TODO: remove redundant units by keeping minimum shift, highest_amplitude, or # max_spikes ks3_output = sc.remove_redundant_units( - ks3_no_empt, - WaveformExtractor, # spike trains realigned using the peak shift in template + waveforms, # spike trains realigned using the peak shift in template duplicate_threshold=0.9, # default is 0.8 remove_strategy='minimum_shift', # keep unit with best peak alignment ) """ - - # TODO + # export to phy, with pc feature calculated. + # copy recording.dat to output so that individual waveforms can be + # seen in waveformview. sexp.export_to_phy( waveform_extractor=waveforms, - output_folder=output / "phy_ks3", - compute_pc_features=True, + output_folder=for_phy, + compute_pc_features=True, # pca compute_amplitudes=True, copy_binary=True, - remove_if_exists=True, # load if already exported to phy + #remove_if_exists=True, # overwrite everytime + remove_if_exists=False, # load if already exists **job_kwargs, ) - assert 0 + print(f"> Parameters for manual curation saved to {for_phy}.\ + \n DO NOT FORGET TO COPY cluster_KSLabel.tsv from {output} to {for_phy}.\n") def extract_videos(self, force=False): From 68dcd213a91b1b5262985cf324d298d51cc6c37a Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 7 Dec 2022 14:06:41 +0000 Subject: [PATCH 016/658] use spikeinterface export_to_phy func for old data too; make sure cluster info can be found --- pixels/behaviours/base.py | 156 +++++++++++++++++++++++++------------- 1 file changed, 102 insertions(+), 54 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 35d37bf..b72f384 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -396,6 +396,7 @@ def process_spikes(self): continue data_file = self.find_file(recording['spike_data']) + """ orig_rate = self.spike_meta[rec_num]['imSampRate'] num_chans = self.spike_meta[rec_num]['nSavedChans'] @@ -423,6 +424,7 @@ def process_spikes(self): data = data[- lag_start:] data = pd.DataFrame(data[:, :-1]) ioutils.write_hdf5(output, data) + """ def process_lfp(self): """ @@ -484,7 +486,7 @@ def process_lfp(self): #ioutils.write_hdf5(output, downsampled) - def sort_spikes(self): + def sort_spikes(self, old=False): """ Run kilosort spike sorting on raw spike data. 
""" @@ -525,7 +527,7 @@ def sort_spikes(self): # check if already sorted and exported for_phy = output / "phy_ks3" - if not os.path.exists(for_phy) or len(os.listdir(for_phy)) == 0: + if not os.path.exists(for_phy) or len(os.listdir(for_phy)) == 1: print("> Not sorted yet, start spike sorting...\n") else: print("> Already sorted and exported, next session...\n") @@ -540,56 +542,67 @@ def sort_spikes(self): f"Did the raw data get fully copied to interim? Full error: {e}\n" ) - try: - ks3_output = si.load_extractor(output / - f'saved_si_sorting_obj_{stream_num}') - print("> This session is already sorted, now it is loaded.\n") - except: - print("> Running kilosort\n") - # concatenate recording segments - concat_rec = si.concatenate_recordings([recording]) - probe = pi.read_spikeglx(metadata.as_posix()) - concat_rec = concat_rec.set_probe(probe) - # annotate spike data is filtered - concat_rec.annotate(is_filtered=True) - print(f"> Now is sorting: \n{concat_rec}\n") - - """ - # for testing: get first 5 mins of the recording - fs = concat_rec.get_sampling_frequency() - test = concat_rec.frame_slice( - start_frame=0*fs, - end_frame=300*fs, - ) - test.annotate(is_filtered=True) - # check all annotations - test.get_annotation('is_filtered') - print(test) - """ - - #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) - ks3_output = ss.run_sorter( - sorter_name='kilosort3', - recording=concat_rec, #recording=test, # for testing - output_folder=output, - #remove_existing_folder=False, - **job_kwargs, - ) - + if old: + print("\n> loading old kilosort 3 results to spikeinterface") + sorting = se.read_kilosort(old_ks_output_dir) # avoid re-sort old # remove empty units - ks3_output = ks3_output.remove_empty_units() - print(f"KS3 found {len(ks3_output.get_unit_ids())} non-empty units.\n") - - """ - #TODO: remove duplicated spikes from spike train, only in >0.96.1 si - ks3_output = sc.remove_duplicated_spikes( - sorting=ks3_no_empt, - censored_period_ms=0.3, #ms - method='keep_first', # keep first spike, remove the second - ) - """ - # save spikeinterface sorting object for easier loading - ks3_output.save(folder=output / 'saved_si_sorting_obj') + ks3_output = sorting.remove_empty_units() + print(f"> KS3 removed\ + {len(sorting.get_unit_ids()) - len(ks3_output.get_unit_ids())}\ + non-empty units.\n") + else: + try: + ks3_output = si.load_extractor(output / + f'saved_si_sorting_obj_{stream_num}') + print("> This session is already sorted, now it is loaded.\n") + except: + print("> Running kilosort\n") + # concatenate recording segments + concat_rec = si.concatenate_recordings([recording]) + probe = pi.read_spikeglx(metadata.as_posix()) + concat_rec = concat_rec.set_probe(probe) + # annotate spike data is filtered + concat_rec.annotate(is_filtered=True) + print(f"> Now is sorting: \n{concat_rec}\n") + + """ + # for testing: get first 5 mins of the recording + fs = concat_rec.get_sampling_frequency() + test = concat_rec.frame_slice( + start_frame=0*fs, + end_frame=300*fs, + ) + test.annotate(is_filtered=True) + # check all annotations + test.get_annotation('is_filtered') + print(test) + """ + + #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) + sorting = ss.run_sorter( + sorter_name='kilosort3', + recording=concat_rec, #recording=test, # for testing + output_folder=output, + #remove_existing_folder=False, + **job_kwargs, + ) + + # remove empty units + ks3_output = sorting.remove_empty_units() + print(f"> KS3 removed\ + {len(sorting.get_unit_ids()) - 
len(ks3_output.get_unit_ids())}\ + non-empty units.\n") + + """ + #TODO: remove duplicated spikes from spike train, only in >0.96.1 si + ks3_output = sc.remove_duplicated_spikes( + sorting=ks3_no_empt, + censored_period_ms=0.3, #ms + method='keep_first', # keep first spike, remove the second + ) + """ + # save spikeinterface sorting object for easier loading + ks3_output.save(folder=output / 'saved_si_sorting_obj') try: waveforms = si.WaveformExtractor.load_from_folder( @@ -634,8 +647,33 @@ def sort_spikes(self): remove_if_exists=False, # load if already exists **job_kwargs, ) - print(f"> Parameters for manual curation saved to {for_phy}.\ - \n DO NOT FORGET TO COPY cluster_KSLabel.tsv from {output} to {for_phy}.\n") + print(f"> Parameters for manual curation saved to {for_phy}.\n") + + correct_kslabels = for_phy / "cluster_KSLabel.tsv", + if os.path.exists(correct_kslabels): + print(f"\nCorrect KS labels already saved in {correct_kslabels}. Next session.\n") + continue + + print("\n> Getting all KS labels...") + all_ks_labels = pd.read_csv( + output / "cluster_KSLabel.tsv", + sep='\t', + ) + print("\n> Finding cluster ids from spikeinterface output...") + new_clus_ids = pd.read_csv( + for_phy / "cluster_si_unit_ids.tsv", + sep='\t', + ) + units = new_clus_ids.si_unit_id.to_list() + + print("\n> Saving correct ks labels...") + selected_kslabels = all_ks_labels.iloc[units].reset_index(drop=True) + selected_kslabels.loc[:, "cluster_id"] = [i for i in range(new_clus_ids.shape[0])] + selected_kslabels.to_csv( + correct_kslabels, + sep='\t', + index=False, + ) def extract_videos(self, force=False): @@ -1747,7 +1785,17 @@ def align_clips(self, label, event, video_match, duration=1): def get_cluster_info(self): if self._cluster_info is None: - info_file = self.processed / 'sorted_stream_0' / 'cluster_info.tsv' + if len(self.catGT_dir) == 0: + print(f"> cluster info not found for {files['catGT_ap_data']},\ + \nfind cluster info in old place.\n") + info_file = self.processed / 'sorted_stream_0' / 'cluster_info.tsv' + else: + if not (self.processed / 'sorted_stream_cat_0' / 'phy_ks3').exists(): + print("> getting cluster info from original kilosort output folder\n") + info_file = self.processed / 'sorted_stream_cat_0' / 'cluster_info.tsv' + else: + print("> getting cluster info from spikeinterface export folder\n") + info_file = self.processed / 'sorted_stream_cat_0' / 'phy_ks3' / 'cluster_info.tsv' try: info = pd.read_csv(info_file, sep='\t') except FileNotFoundError: From f72c00bcce0f52d6a35e46747c6181ad72d8ce5a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 9 Dec 2022 15:46:21 +0000 Subject: [PATCH 017/658] do not copy lfp meta to interim so that catgt can create lfp stream from full-band ap.bin; optimise sort_spikes to incorporate old sorted streams --- pixels/behaviours/base.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b72f384..b005a0c 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -165,7 +165,7 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): ioutils.read_meta(self.find_file(f['spike_meta'])) for f in self.files ] self.lfp_meta = [ - ioutils.read_meta(self.find_file(f['lfp_meta'])) for f in self.files + ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files ] # environmental variable PIXELS_CACHE={0,1} can be {disable,enable} cache @@ -503,12 +503,12 @@ def sort_spikes(self, old=False): # find spike 
data to sort for _, files in enumerate(self.files): if len(self.catGT_dir) == 0: - print(f"> Spike data not found for {files['catGT_ap_data']},\ + print(f"\n> Spike data not found for {files['catGT_ap_data']},\ \nuse the orignial spike data.\n") data_file = self.find_file(files['spike_data']) metadata = self.find_file(files['spike_meta']) else: - print("> Sorting catgt-ed spikes\n") + print("\n> Sorting catgt-ed spikes\n") self.catGT_dir = Path(self.catGT_dir[0]) data_file = self.catGT_dir / files['catGT_ap_data'] metadata = self.catGT_dir / files['catGT_ap_meta'] @@ -527,8 +527,8 @@ def sort_spikes(self, old=False): # check if already sorted and exported for_phy = output / "phy_ks3" - if not os.path.exists(for_phy) or len(os.listdir(for_phy)) == 1: - print("> Not sorted yet, start spike sorting...\n") + if not for_phy.exists() or not len(os.listdir(for_phy)) > 1: + print("> Not sorted or exported yet, start from spike sorting...\n") else: print("> Already sorted and exported, next session...\n") continue @@ -548,8 +548,8 @@ def sort_spikes(self, old=False): # remove empty units ks3_output = sorting.remove_empty_units() print(f"> KS3 removed\ - {len(sorting.get_unit_ids()) - len(ks3_output.get_unit_ids())}\ - non-empty units.\n") + \n{len(sorting.get_unit_ids()) - len(ks3_output.get_unit_ids())}\ + empty units.\n") else: try: ks3_output = si.load_extractor(output / @@ -590,8 +590,8 @@ def sort_spikes(self, old=False): # remove empty units ks3_output = sorting.remove_empty_units() print(f"> KS3 removed\ - {len(sorting.get_unit_ids()) - len(ks3_output.get_unit_ids())}\ - non-empty units.\n") + \n{len(sorting.get_unit_ids()) - len(ks3_output.get_unit_ids())}\ + empty units.\n") """ #TODO: remove duplicated spikes from spike train, only in >0.96.1 si @@ -604,6 +604,8 @@ def sort_spikes(self, old=False): # save spikeinterface sorting object for easier loading ks3_output.save(folder=output / 'saved_si_sorting_obj') + #TODO: toggle load_if_exists=True & overwrite=False should replace + #...load_from_folder. try: waveforms = si.WaveformExtractor.load_from_folder( folder=self.interim / 'cache', @@ -637,6 +639,7 @@ def sort_spikes(self, old=False): # export to phy, with pc feature calculated. # copy recording.dat to output so that individual waveforms can be # seen in waveformview. + print("> Exporting parameters for phy...\n") sexp.export_to_phy( waveform_extractor=waveforms, output_folder=for_phy, @@ -649,8 +652,8 @@ def sort_spikes(self, old=False): ) print(f"> Parameters for manual curation saved to {for_phy}.\n") - correct_kslabels = for_phy / "cluster_KSLabel.tsv", - if os.path.exists(correct_kslabels): + correct_kslabels = for_phy / "cluster_KSLabel.tsv" + if correct_kslabels.exists(): print(f"\nCorrect KS labels already saved in {correct_kslabels}. 
Next session.\n") continue From 3bc0637f24ce29364563bb5705d6871fa786bb95 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 16 Dec 2022 00:35:43 +0000 Subject: [PATCH 018/658] start adding run_catgt; get cluster info from the right place; optimise sort spikes --- pixels/behaviours/base.py | 102 +++++++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 23 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b005a0c..cb9f111 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -486,7 +486,61 @@ def process_lfp(self): #ioutils.write_hdf5(output, downsampled) - def sort_spikes(self, old=False): + def run_catgt(self, CatGT_app=None, args=None) -> None: + """ + This func performs CatGT on copied AP data in the interim. + + params + ==== + data_dir: path, dir to interim data and catgt output. + + catgt_app: path, dir to catgt software. + + args: str, arguments in catgt. + default is None. + """ + if CatGT_app == None: + CatGT_app = "~/CatGT3.4" + # move cwd to catgt + os.chdir(CatGT_app) + + assert 0 + for rec_num, recording in enumerate(self.files): + self.find_file(recording['spike_data']) + + # reset catgt args for current session + session_args = None + + if len(self.catGT_dir) != 0: + if len(os.listdir(self.catGT_dir[0])) != 0: + print(f"\nCatGT already performed on ap data of {self.name}. Next session.\n") + continue + + #TODO: finish this here so that catgt can run together with sorting + print(f"> running CatGT on ap data of {self.name}") + #_dir = self.interim + + if args == None: + args = f"-no_run_fld\ + -g=0,9\ + -t=0,9\ + -prb=0\ + -ap\ + -lf\ + -apfilter=butter,12,300,9000\ + -lffilter=butter,12,0.5,300\ + -gblcar\ + -gfix=0.2,0.1,0.02" + + session_args = f"-dir={self.interim} -run={self.name} -dest={self.interim} " + args + print(f"\ncatgt args of {self.name}: \n{session_args}") + + subprocess.run( ['./run_catgt.sh', session_args]) + + assert 0 + + + def sort_spikes(self, CatGT_app=None, old=False): """ Run kilosort spike sorting on raw spike data. """ @@ -499,6 +553,8 @@ def sort_spikes(self, old=False): ) #TODO: consider to put catgt here + #if not CatGT_app == None: + # self.run_catgt(CatGT_app=CatGT_app) # find spike data to sort for _, files in enumerate(self.files): @@ -542,6 +598,13 @@ def sort_spikes(self, old=False): f"Did the raw data get fully copied to interim? 
Full error: {e}\n" ) + # concatenate recording segments + concat_rec = si.concatenate_recordings([recording]) + probe = pi.read_spikeglx(metadata.as_posix()) + concat_rec = concat_rec.set_probe(probe) + # annotate spike data is filtered + concat_rec.annotate(is_filtered=True) + if old: print("\n> loading old kilosort 3 results to spikeinterface") sorting = se.read_kilosort(old_ks_output_dir) # avoid re-sort old @@ -552,18 +615,8 @@ def sort_spikes(self, old=False): empty units.\n") else: try: - ks3_output = si.load_extractor(output / - f'saved_si_sorting_obj_{stream_num}') + ks3_output = si.load_extractor(output / 'saved_si_sorting_obj') print("> This session is already sorted, now it is loaded.\n") - except: - print("> Running kilosort\n") - # concatenate recording segments - concat_rec = si.concatenate_recordings([recording]) - probe = pi.read_spikeglx(metadata.as_posix()) - concat_rec = concat_rec.set_probe(probe) - # annotate spike data is filtered - concat_rec.annotate(is_filtered=True) - print(f"> Now is sorting: \n{concat_rec}\n") """ # for testing: get first 5 mins of the recording @@ -578,6 +631,9 @@ def sort_spikes(self, old=False): print(test) """ + except: + print("> Running kilosort\n") + print(f"> Now is sorting: \n{concat_rec}\n") #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) sorting = ss.run_sorter( sorter_name='kilosort3', @@ -639,7 +695,7 @@ def sort_spikes(self, old=False): # export to phy, with pc feature calculated. # copy recording.dat to output so that individual waveforms can be # seen in waveformview. - print("> Exporting parameters for phy...\n") + print("\n> Exporting parameters for phy...\n") sexp.export_to_phy( waveform_extractor=waveforms, output_folder=for_phy, @@ -1788,17 +1844,17 @@ def align_clips(self, label, event, video_match, duration=1): def get_cluster_info(self): if self._cluster_info is None: - if len(self.catGT_dir) == 0: - print(f"> cluster info not found for {files['catGT_ap_data']},\ - \nfind cluster info in old place.\n") - info_file = self.processed / 'sorted_stream_0' / 'cluster_info.tsv' - else: - if not (self.processed / 'sorted_stream_cat_0' / 'phy_ks3').exists(): - print("> getting cluster info from original kilosort output folder\n") - info_file = self.processed / 'sorted_stream_cat_0' / 'cluster_info.tsv' + if not (self.processed / 'sorted_stream_cat_0' / 'phy_ks3' / + 'cluster_info.tsv').exists(): + if not (self.processed / 'sorted_stream_cat_0' / + 'cluster_info.tsv').exists(): + info_file = self.processed / 'sorted_stream_0' / 'cluster_info.tsv' else: - print("> getting cluster info from spikeinterface export folder\n") - info_file = self.processed / 'sorted_stream_cat_0' / 'phy_ks3' / 'cluster_info.tsv' + info_file = self.processed / 'sorted_stream_cat_0' / 'cluster_info.tsv' + else: + info_file = self.processed / 'sorted_stream_cat_0' / 'phy_ks3' / 'cluster_info.tsv' + + print(f"> got cluster info at {info_file}\n") try: info = pd.read_csv(info_file, sep='\t') except FileNotFoundError: From bd89879de6b847276440b14ae471804e0a94710c Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 19 Dec 2022 12:07:28 +0000 Subject: [PATCH 019/658] run catgt with sort spike --- pixels/behaviours/base.py | 48 ++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index cb9f111..4c61145 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -14,6 +14,7 @@ import tarfile import tempfile import re 
+import subprocess from abc import ABC, abstractmethod from collections import defaultdict from pathlib import Path @@ -144,7 +145,7 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): else: self.interim = Path(interim_dir) / self.name - self.catGT_dir = glob.glob( + self.CatGT_dir = glob.glob( str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' ) @@ -504,20 +505,20 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: # move cwd to catgt os.chdir(CatGT_app) - assert 0 for rec_num, recording in enumerate(self.files): + # copy spike data to interim self.find_file(recording['spike_data']) # reset catgt args for current session session_args = None - if len(self.catGT_dir) != 0: - if len(os.listdir(self.catGT_dir[0])) != 0: + if len(self.CatGT_dir) != 0: + if len(os.listdir(self.CatGT_dir[0])) != 0: print(f"\nCatGT already performed on ap data of {self.name}. Next session.\n") continue #TODO: finish this here so that catgt can run together with sorting - print(f"> running CatGT on ap data of {self.name}") + print(f"> Running CatGT on ap data of {self.name}") #_dir = self.interim if args == None: @@ -529,6 +530,7 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: -lf\ -apfilter=butter,12,300,9000\ -lffilter=butter,12,0.5,300\ + -xd=2,0,384,6,500\ -gblcar\ -gfix=0.2,0.1,0.02" @@ -537,8 +539,6 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: subprocess.run( ['./run_catgt.sh', session_args]) - assert 0 - def sort_spikes(self, CatGT_app=None, old=False): """ @@ -552,22 +552,18 @@ def sort_spikes(self, CatGT_app=None, old=False): progress_bar=True, ) - #TODO: consider to put catgt here - #if not CatGT_app == None: - # self.run_catgt(CatGT_app=CatGT_app) - - # find spike data to sort for _, files in enumerate(self.files): - if len(self.catGT_dir) == 0: - print(f"\n> Spike data not found for {files['catGT_ap_data']},\ - \nuse the orignial spike data.\n") + if not CatGT_app == None: + self.run_catgt(CatGT_app=CatGT_app) + + print("\n> Sorting catgt-ed spikes\n") + self.CatGT_dir = Path(self.CatGT_dir[0]) + data_file = self.CatGT_dir / files['catGT_ap_data'] + metadata = self.CatGT_dir / files['catGT_ap_meta'] + else: + print(f"\n> using the orignial spike data.\n") data_file = self.find_file(files['spike_data']) metadata = self.find_file(files['spike_meta']) - else: - print("\n> Sorting catgt-ed spikes\n") - self.catGT_dir = Path(self.catGT_dir[0]) - data_file = self.catGT_dir / files['catGT_ap_data'] - metadata = self.catGT_dir / files['catGT_ap_meta'] stream_id = data_file.as_posix()[-12:-4] if stream_id not in streams: @@ -590,7 +586,7 @@ def sort_spikes(self, CatGT_app=None, old=False): continue try: - recording = se.SpikeGLXRecordingExtractor(self.catGT_dir, stream_id=stream_id) + recording = se.SpikeGLXRecordingExtractor(self.CatGT_dir, stream_id=stream_id) # this recording is filtered recording.annotate(is_filtered=True) except ValueError as e: @@ -1948,16 +1944,16 @@ def get_spike_waveforms(self, units=None, method='phy'): # load recording and sorting object for _, files in enumerate(self.files): - if len(self.catGT_dir) == 0: + if len(self.CatGT_dir) == 0: print(f"> Spike data not found for {files['catGT_ap_data']},\ \nuse the orignial recording data.") data_file = self.find_file(files['spike_data']) metadata = self.find_file(files['spike_meta']) else: print("> Use catgt-ed recording") - self.catGT_dir = Path(self.catGT_dir[0]) - data_file = self.catGT_dir / files['catGT_ap_data'] - metadata = self.catGT_dir / files['catGT_ap_meta'] + self.CatGT_dir 
= Path(self.CatGT_dir[0]) + data_file = self.CatGT_dir / files['catGT_ap_data'] + metadata = self.CatGT_dir / files['catGT_ap_meta'] stream_id = data_file.as_posix()[-12:-4] if stream_id not in streams: @@ -1966,7 +1962,7 @@ def get_spike_waveforms(self, units=None, method='phy'): for stream_num, stream in enumerate(streams.items()): stream_id, metadata = stream try: - recording = se.SpikeGLXRecordingExtractor(self.catGT_dir, stream_id=stream_id) + recording = se.SpikeGLXRecordingExtractor(self.CatGT_dir, stream_id=stream_id) # this recording is filtered recording.annotate(is_filtered=True) except ValueError as e: From 57dd6dfa403dfbb93117d6d83d3cad09c9a51c54 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 19 Dec 2022 12:07:45 +0000 Subject: [PATCH 020/658] add catgt option --- pixels/experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 7346c85..70ed042 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -116,12 +116,12 @@ def process_spikes(self): .format(session.name, i + 1, len(self.sessions))) session.process_spikes() - def sort_spikes(self): + def sort_spikes(self, CatGT_app=None): """ Extract the spikes from raw spike data for all sessions. """ for i, session in enumerate(self.sessions): print(">>>>> Sorting spikes for session {} ({} / {})" .format(session.name, i + 1, len(self.sessions))) - session.sort_spikes() + session.sort_spikes(CatGT_app=CatGT_app) def assess_noise(self): """ From 096f866d8f2cd690b19ff0c7d67605da4d6a7208 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 19 Dec 2022 12:10:48 +0000 Subject: [PATCH 021/658] start adding action labels for passive viewing --- pixels/behaviours/passive_viewing.py | 514 +++++++++++++++++++++++++++ 1 file changed, 514 insertions(+) create mode 100644 pixels/behaviours/passive_viewing.py diff --git a/pixels/behaviours/passive_viewing.py b/pixels/behaviours/passive_viewing.py new file mode 100644 index 0000000..a2fc215 --- /dev/null +++ b/pixels/behaviours/passive_viewing.py @@ -0,0 +1,514 @@ +""" +This module provides passive viewing task specific operations. +""" + +from __future__ import annotations + +import pickle + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +from reach.session import Outcomes, Targets + +from pixels import Experiment, PixelsError +from pixels import signal, ioutils +from pixels.behaviours import Behaviour + + +class ActionLabels: + """ + These actions cover all possible trial types. + 'g0.04' and 'g0.16' correspond to the grating trial's spatial frequency. This + means trials with all temporal frequencies are included here. + 'nat_movie' and 'cage_movie' correspond to movie trials, where a 60s of natural + or animal home cage movie is played. + + for ESCN00-03, start of each of these trials is marked by a rise of 500 ms TTL + pulse. + Within each grating trial, each orientation grating shows for 2s, their + beginnings and ends are both marked by a rise of 500 ms TTL pulse. + + *dec 16th 2022: in the near future, there will be 'nat_image' trial, where + + To align trials to more than one action type they can be bitwise OR'd i.e. + `miss_left | miss_right` will match all miss trials. 
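As a minimal, self-contained illustration of the bitwise-flag matching described in this docstring (the flag values and the toy label array below are invented for the example and are not part of the patch):

    import numpy as np

    class _Labels:
        # illustrative subset; the real classes define many more flags
        miss_left = 1 << 0
        miss_right = 1 << 1
        correct_left = 1 << 2

    # one label value per timepoint, as in action_labels[:, 0]
    actions = np.array([_Labels.miss_left, 0, _Labels.correct_left, _Labels.miss_right])

    # OR'ing flags builds a mask that matches any of the combined trial types
    miss_any = _Labels.miss_left | _Labels.miss_right
    trial_starts = np.where(np.bitwise_and(actions, miss_any))[0]
    # trial_starts -> array([0, 3])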
+ """ + # gratings 0.04 spatial freq + g_0.04_0 = 1 << 0 + g_0.04_30 = 1 << 1 + g_0.04_60 = 1 << 2 + g_0.04_90 = 1 << 3 + g_0.04_120 = 1 << 4 + g_0.04_150 = 1 << 5 + g_0.04_180 = 1 << 6 + g_0.04_210 = 1 << 7 + g_0.04_240 = 1 << 8 + g_0.04_270 = 1 << 9 + g_0.04_300 = 1 << 10 + g_0.04_330 = 1 << 11 + + # gratings 0.16 spatial freq + g_0.16_0 = 1 << 12 + g_0.16_30 = 1 << 13 + g_0.16_60 = 1 << 14 + g_0.16_90 = 1 << 15 + g_0.16_120 = 1 << 16 + g_0.16_150 = 1 << 17 + g_0.16_180 = 1 << 18 + g_0.16_210 = 1 << 19 + g_0.16_240 = 1 << 20 + g_0.16_270 = 1 << 21 + g_0.16_300 = 1 << 22 + g_0.16_330 = 1 << 23 + + # movies + nat_movie = 1 << 24 + cage_movie = 1 << 25 + + # iti marks + black = 1 << 26 + gray = 1 << 27 + + #TODO: natural images + + +class Events: + led_on = 1 << 0 + led_off = 1 << 1 + + # Timepoints determined from motion tracking + reach_onset = 1 << 2 + slit_in = 1 << 3 + grasp = 1 << 4 + slit_out = 1 << 5 + subsequent_slit_in = 1 << 6 # The SECOND full reach on a clean correct trial only + subsequent_grasp = 1 << 7 + subsequent_slit_out = 1 << 8 + + +# These are used to convert the trial data into Actions and Events +_side_map = { + Targets.LEFT: "left", + Targets.RIGHT: "right", +} + +_action_map = { + Outcomes.MISSED: "miss", + Outcomes.CORRECT: "correct", + Outcomes.INCORRECT: "incorrect", +} + + + +class Reach(Behaviour): + def _preprocess_behaviour(self, rec_num, behavioural_data): + # Correction for sessions where sync channel interfered with LED channel + if behavioural_data["/'ReachLEDs'/'0'"].min() < -2: + behavioural_data["/'ReachLEDs'/'0'"] = behavioural_data["/'ReachLEDs'/'0'"] \ + + 0.5 * behavioural_data["/'NpxlSync_Signal'/'0'"] + + behavioural_data = signal.binarise(behavioural_data) + action_labels = np.zeros((len(behavioural_data), 2), dtype=np.uint64) + + try: + cue_leds = behavioural_data["/'ReachLEDs'/'0'"].values + except KeyError: + # some early recordings still used this key + cue_leds = behavioural_data["/'Back_Sensor'/'0'"].values + + led_onsets = np.where((cue_leds[:-1] == 0) & (cue_leds[1:] == 1))[0] + led_offsets = np.where((cue_leds[:-1] == 1) & (cue_leds[1:] == 0))[0] + action_labels[led_onsets, 1] += Events.led_on + action_labels[led_offsets, 1] += Events.led_off + metadata = self.metadata[rec_num] + + # QA: Check that the JSON and TDMS data have the same number of trials + if len(led_onsets) != len(metadata["trials"]): + # If they do not have the same number, perhaps the TDMS was stopped too early + meta_onsets = np.array([t["start"] for t in metadata["trials"]]) * 1000 + meta_onsets = (meta_onsets - meta_onsets[0] + led_onsets[0]).astype(int) + if meta_onsets[-1] > len(cue_leds): + # TDMS stopped too early, continue anyway. + i = -1 + while meta_onsets[i] > len(cue_leds): + metadata["trials"].pop() + i -= 1 + assert len(led_onsets) == len(metadata["trials"]) + else: + # If you have come to debug and see why this error was raised, try: + # led_onsets - meta_onsets[:len(led_onsets)] # This might show the problem + # meta_onsets - led_onsets[:len(meta_onsets)] # This might show the problem + # Then just patch a fix here: + if self.name == "211027_VR49" and rec_num == 1: + del metadata["trials"][52] # Maybe cable fell out of DAQ input? + else: + raise PixelsError( + f"{self.name}: Mantis and Raspberry Pi behavioural " + "data have different no. of trials" + ) + + # QA: Last offset not found in tdms data? 
+ if len(led_offsets) < len(led_onsets): + last_trial = self.metadata[rec_num]['trials'][-1] + if "end" in last_trial: + # Take known offset from metadata + offset = led_onsets[-1] + (last_trial['end'] - last_trial['start']) * 1000 + led_offsets = np.append(led_offsets, int(offset)) + else: + # If not possible, just remove last onset + led_onsets = led_onsets[:-1] + metadata["trials"].pop() + assert len(led_offsets) == len(led_onsets) + + # QA: For some reason, sometimes the final trial metadata doesn't include the + # final led-off even though it is detectable in the TDMS data. + elif len(led_offsets) == len(led_onsets): + # Not sure how to deal with this if led_offsets and led_onsets differ in length + if len(metadata["trials"][-1]) == 1 and "start" in metadata["trials"][-1]: + # Remove it, because we would have to check the video to get all of the + # information about the trial, and it's too complicated. + metadata["trials"].pop() + led_onsets = led_onsets[:-1] + led_offsets = led_offsets[:-1] + + # QA: Check that the cue durations (mostly) match between JSON and TDMS data + # This compares them at 10s of milliseconds resolution + cue_durations_tdms = (led_offsets - led_onsets) / 100 + cue_durations_json = np.array( + [t['end'] - t['start'] for t in metadata['trials']] + ) * 10 + error = sum( + (cue_durations_tdms - cue_durations_json).round() != 0 + ) / len(led_onsets) + if error > 0.05: + raise PixelsError( + f"{self.name}: Mantis and Raspberry Pi behavioural data have mismatching trial data." + ) + + return behavioural_data, action_labels, led_onsets + + def _extract_action_labels(self, rec_num, behavioural_data, plot=False): + behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(rec_num, behavioural_data) + + for i, trial in enumerate(self.metadata[rec_num]["trials"]): + side = _side_map[trial["spout"]] + outcome = trial["outcome"] + if outcome in _action_map: + action = _action_map[trial["outcome"]] + action_labels[led_onsets[i], 0] += getattr(ActionLabels, f"{action}_{side}") + + if plot: + plt.clf() + _, axes = plt.subplots(4, 1, sharex=True, sharey=True) + axes[0].plot(back_sensor_signal) + if "/'Back_Sensor'/'0'" in behavioural_data: + axes[1].plot(behavioural_data["/'Back_Sensor'/'0'"].values) + else: + axes[1].plot(behavioural_data["/'ReachCue_LEDs'/'0'"].values) + axes[2].plot(action_labels[:, 0]) + axes[3].plot(action_labels[:, 1]) + plt.plot(action_labels[:, 1]) + plt.show() + + return action_labels + + def draw_slit_thresholds(self, project: str, force: bool = False): + """ + Draw lines on the slits using EasyROI. If ROIs already exist, skip. + + Parameters + ========== + project : str + The DLC project i.e. name/prefix of the camera. + + force : bool + If true, we will draw new lines even if the output file exists. + + """ + # Only needed for this method + import cv2 + import EasyROI + + output = self.processed / f"slit_thresholds_{project}.pickle" + + if output.exists() and not force: + print(self.name, "- slits drawn already.") + return + + # Let's take the average between the first and last frames of the whole session. + videos = [] + + for recording in self.files: + for v, video in enumerate(recording.get("camera_data", [])): + if project in video.stem: + avi = self.interim / video.with_suffix('.avi') + if not avi.exists(): + meta = recording['camera_meta'][v] + ioutils.tdms_to_video( + self.find_file(video, copy=False), + self.find_file(meta), + avi, + ) + if not avi.exists(): + raise PixelsError(f"Path {avi} should exist but doesn't... 
discuss.") + videos.append(avi.as_posix()) + + if not videos: + raise PixelsError("No videos were found to draw slits on.") + + first_frame = ioutils.load_video_frame(videos[0], 1) + last_duration = ioutils.get_video_dimensions(videos[-1])[2] + last_frame = ioutils.load_video_frame(videos[-1], last_duration - 1) + + average_frame = np.concatenate( + [first_frame[..., None], last_frame[..., None]], + axis=2, + ).mean(axis=2) + average_frame = np.squeeze(average_frame) / 255 + + # Interactively draw ROI + global _roi_helper + if _roi_helper is None: + # Ugly but we can only have one instance of this + _roi_helper = EasyROI.EasyROI(verbose=False) + lines = _roi_helper.draw_line(average_frame, 2) + cv2.destroyAllWindows() # Needed otherwise EasyROI errors + + # Save a copy of the frame with ROIs to PNG file + png = output.with_suffix(".png") + copy = EasyROI.visualize_line(average_frame, lines, color=(255, 0, 0)) + plt.imsave(png, copy, cmap='gray') + + # Save lines to file + with output.open('wb') as fd: + pickle.dump(lines['roi'], fd) + + def inject_slit_crossings(self): + """ + Take the lines drawn from `draw_slit_thresholds` above, get the reach + coordinates from DLC output, identify the timepoints when successful reaches + crossed the lines, and add the `reach_onset` event to those timepoints in the + action labels. Also identify which trials need clearing up, i.e. those with + multiple reaches or have failed motion tracking, and exclude those. + """ + lines = {} + projects = ("LeftCam", "RightCam") + + for project in projects: + line_file = self.processed / f"slit_thresholds_{project}.pickle" + + if not line_file.exists(): + print(self.name, "- Lines not drawn for session.") + return + + with line_file.open("rb") as f: + proj_lines = pickle.load(f) + + lines[project] = { + tt:[pd.Series(p, index=["x", "y"]) for p in points.values()] + for tt, points in proj_lines.items() + } + + action_labels = self.get_action_labels() + event = Events.led_off + + # https://bryceboe.com/2006/10/23/line-segment-intersection-algorithm + def ccw(A, B, C): + return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x) + + for tt, action in enumerate( + (ActionLabels.correct_left, ActionLabels.correct_right), + ): + data = {} + trajectories = {} + + for project in projects: + proj_data = self.align_trials( + action, + event, + "motion_tracking", + duration=6, + dlc_project=project, + ) + proj_traj, = check_scorers(proj_data) + data[project] = proj_traj + trajectories[project] = get_reach_trajectories(proj_traj)[0] + + for rec_num, recording in enumerate(self.files): + actions = action_labels[rec_num][:, 0] + events = action_labels[rec_num][:, 1] + trial_starts = np.where(np.bitwise_and(actions, action))[0] + + for t, start in enumerate(trial_starts): + centre = np.where(np.bitwise_and(events[start:start + 6000], event))[0] + if len(centre) == 0: + raise PixelsError('Action labels probably miscalculated') + centre = start + centre[0] + # centre is index of this rec's grasp + onsets = [] + for project, motion in trajectories.items(): + left_hand = motion["left_hand_median"][t][:0] + right_hand = motion["right_hand_median"][t][:0] + pt1, pt2 = lines[project][tt] + + x_l = left_hand.iloc[-10:]["x"].mean() + x_r = right_hand.iloc[-10:]["x"].mean() + hand = right_hand + if project == "LeftCam": + if x_l > x_r: + hand = left_hand + else: + if x_r > x_l: + hand = left_hand + + segments = zip( + hand.iloc[::-1].iterrows(), + hand.iloc[-2::-1].iterrows(), + ) + + for (end, ptend), (start, ptsta) in segments: + if ( + 
ccw(pt1, ptsta, ptend) != ccw(pt2, ptsta, ptend) and + ccw(pt1, pt2, ptsta) != ccw(pt1, pt2, ptend) + ): + # These lines intersect + #print("x from ", ptsta.x, " to ", ptend.x) + break + #if ptend.y > 300 or ptsta.y > 300: + # assert 0 + onsets.append(start) + + onset = max(onsets) + onset_timepoint = round(centre + (onset * 1000)) + events[onset_timepoint] |= Events.reach_onset + + output = self.processed / recording['action_labels'] + np.save(output, action_labels[rec_num]) + + +global _roi_helper +_roi_helper = None + + +class VisualOnly(Reach): + def _extract_action_labels(self, behavioural_data, plot=False): + behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(behavioural_data) + + for i, trial in enumerate(self.metadata["trials"]): + label = "naive_" + _side_map[trial["spout"]] + "_" + if trial["cue_duration"] > 125: + label += "long" + else: + label += "short" + action_labels[led_onsets[i], 0] += getattr(ActionLabels, label) + + return action_labels + + +def get_reach_velocities(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Get the velocity curves for the provided reach trajectories. + """ + results = [] + + for df in dfs: + df = df.copy() + deltas = np.square(df.iloc[1:].values - df.iloc[:-1].values) + # Fill the start with a row of zeros - each value is delta in previous 1 ms + deltas = np.append(np.zeros((1, deltas.shape[1])), deltas, axis=0) + deltas = np.sqrt(deltas[:, ::2] + deltas[:, 1::2]) + df = df.drop([c for c in df.columns if "y" in c], axis=1) + df = df.rename({"x": "delta"}, axis='columns') + df.values[:] = deltas + results.append(df) + + return tuple(results) + + +def get_reach_trajectories(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Get the median centre point of the hand coordinates - i.e. for the labels for each + of the four digits and hand centre for both paws. 
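As a quick standalone check of the distance computation in get_reach_velocities above (synthetic coordinates, with columns alternating x and y as in the DLC frames; the numbers are invented):

    import numpy as np
    import pandas as pd

    # one body part, one row per frame: a single 3-4-5 step then no movement
    coords = pd.DataFrame({"hand_x": [0.0, 3.0, 3.0], "hand_y": [0.0, 4.0, 4.0]})

    deltas = np.square(coords.iloc[1:].values - coords.iloc[:-1].values)
    deltas = np.append(np.zeros((1, deltas.shape[1])), deltas, axis=0)
    speed = np.sqrt(deltas[:, ::2] + deltas[:, 1::2])
    # speed -> [[0.], [5.], [0.]], i.e. Euclidean distance moved per frame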
+ """ + assert dfs + bodyparts = get_body_parts(dfs[0]) + right_paw = [p for p in bodyparts if p.startswith("right")] + left_paw = [p for p in bodyparts if p.startswith("left")] + + trajectories_l = [] + trajectories_r = [] + + for df in dfs: + per_ses_l = [] + per_ses_r = [] + + # Ugly hack so this function works on single sessions + if "session" not in df.columns.names: + df = pd.concat([df], keys=[0], axis=1) + sessions = [0] + single_session = True + else: + single_session = False + sessions = df.columns.get_level_values("session").unique() + + for s in sessions: + per_trial_l = [] + per_trial_r = [] + + trials = df[s].columns.get_level_values("trial").unique() + for t in trials: + tdf = df[s][t] + left = pd.concat((tdf[p] for p in left_paw), keys=left_paw) + right = pd.concat((tdf[p] for p in right_paw), keys=right_paw) + per_trial_l.append(left.groupby(level=1).median()) + per_trial_r.append(right.groupby(level=1).median()) + + per_ses_l.append(pd.concat(per_trial_l, axis=1, keys=trials)) + per_ses_r.append(pd.concat(per_trial_r, axis=1, keys=trials)) + + trajectories_l.append(pd.concat(per_ses_l, axis=1, keys=sessions)) + trajectories_r.append(pd.concat(per_ses_r, axis=1, keys=sessions)) + + if single_session: + return tuple( + pd.concat( + [trajectories_l[i][0], trajectories_r[i][0]], + axis=1, + keys=["left_hand_median", "right_hand_median"], + ) + for i in range(len(dfs)) + ) + return tuple( + pd.concat( + [trajectories_l[i], trajectories_r[i]], + axis=1, + keys=["left_hand_median", "right_hand_median"], + ) + for i in range(len(dfs)) + ) + + +def check_scorers(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Checks that the scorers are identical for all data in the dataframes. These are the + dataframes as returned from Exp.align_trials for motion_tracking data. + + It returns the dataframes with the scorer index level removed. + """ + scorers = set( + s + for df in dfs + for s in df.columns.get_level_values("scorer").unique() + ) + + assert len(scorers) == 1, scorers + return tuple(df.droplevel("scorer", axis=1) for df in dfs) + + +def get_body_parts(df: pd.DataFrame) -> list[str]: + """ + Get the list of body part labels. 
+ """ + return df.columns.get_level_values("bodyparts").unique() From 52bafff3d3b3cdf61ffe9c8f4ca440998182f586 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 Dec 2022 20:09:31 +0000 Subject: [PATCH 022/658] define kilosort output folder; get probe depth from depth info file --- pixels/behaviours/base.py | 49 +++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4c61145..10e2490 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -138,6 +138,24 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self.raw = self.data_dir / 'raw' / self.name self.processed = self.data_dir / 'processed' / self.name + + ks_output = glob.glob( + str(self.processed) +'/' + f'sorted_stream_cat_[0-9]' + ) + if not len(ks_output) == 0: + ks_output = Path(ks_output[0]) + if not ((ks_output / 'phy_ks3').exists() and + len(os.listdir(ks_output / 'phy_ks3'))>17): + #if not (ks_output.exists() and + #len(os.listdir(ks_output / ks_output))>17): + self.ks_output = ks_output + else: + self.ks_output = ks_output / 'phy_ks3' + else: + self.ks_output = Path(glob.glob( + str(self.processed) +'/' + f'sorted_stream_[0-9]' + )[0]) + self.files = ioutils.get_data_files(self.raw, name) if interim_dir is None: @@ -163,7 +181,7 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self.drop_data() self.spike_meta = [ - ioutils.read_meta(self.find_file(f['spike_meta'])) for f in self.files + ioutils.read_meta(self.find_file(f['spike_meta'], copy=False)) for f in self.files ] self.lfp_meta = [ ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files @@ -214,10 +232,18 @@ def get_probe_depth(self): """ depth_file = self.processed / 'depth.txt' if not depth_file.exists(): - msg = f": Can't load probe depth: please add it in um to processed/{self.name}/depth.txt" + depth_file = self.processed / self.files[0]["depth_info"] + + if not depth_file.exists(): + msg = f": Can't load probe depth: please add it in um to\ + \nprocessed/{self.name}/depth.txt, or save it with other depth related\ + \ninfo in {self.processed / self.files[0]['depth_info']}." 
raise PixelsError(msg) - with depth_file.open() as fd: - return [float(line) for line in fd.readlines()] + if Path(depth_file).suffix == ".txt": + with depth_file.open() as fd: + return [float(line) for line in fd.readlines()] + elif Path(depth_file).suffix == ".json": + return [json.load(open(depth_file, mode="r"))["clustering"]] def find_file(self, name: str, copy: bool=True) -> Optional[Path]: """ @@ -1840,17 +1866,9 @@ def align_clips(self, label, event, video_match, duration=1): def get_cluster_info(self): if self._cluster_info is None: - if not (self.processed / 'sorted_stream_cat_0' / 'phy_ks3' / - 'cluster_info.tsv').exists(): - if not (self.processed / 'sorted_stream_cat_0' / - 'cluster_info.tsv').exists(): - info_file = self.processed / 'sorted_stream_0' / 'cluster_info.tsv' - else: - info_file = self.processed / 'sorted_stream_cat_0' / 'cluster_info.tsv' - else: - info_file = self.processed / 'sorted_stream_cat_0' / 'phy_ks3' / 'cluster_info.tsv' - + info_file = self.ks_output / 'cluster_info.tsv' print(f"> got cluster info at {info_file}\n") + try: info = pd.read_csv(info_file, sep='\t') except FileNotFoundError: @@ -1904,7 +1922,8 @@ def get_spike_waveforms(self, units=None, method='phy'): units = self.select_units() - paramspy = self.processed / 'sorted_stream_0' / 'params.py' + #paramspy = self.processed / 'sorted_stream_0' / 'params.py' + paramspy = self.ks_output / 'params.py' if not paramspy.exists(): raise PixelsError(f"{self.name}: params.py not found") model = load_model(paramspy) From 0b02677b03aded8006fef349668bf1178e93ce37 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 Dec 2022 20:16:57 +0000 Subject: [PATCH 023/658] finish defining visual stim labels and events --- pixels/behaviours/passive_viewing.py | 124 ++++++++++++++++----------- 1 file changed, 72 insertions(+), 52 deletions(-) diff --git a/pixels/behaviours/passive_viewing.py b/pixels/behaviours/passive_viewing.py index a2fc215..f107339 100644 --- a/pixels/behaviours/passive_viewing.py +++ b/pixels/behaviours/passive_viewing.py @@ -17,76 +17,95 @@ from pixels.behaviours import Behaviour -class ActionLabels: +class VisStimLabels: """ - These actions cover all possible trial types. - 'g0.04' and 'g0.16' correspond to the grating trial's spatial frequency. This - means trials with all temporal frequencies are included here. + These visual stim labels cover all possible trial types. + 'sf_004' and 'sf_016' correspond to the grating trial's spatial frequency. + This means trials with all temporal frequencies are included here. + 'left' and 'right' notes the direction where the grating is drifting towards. + Gratings have six possible orientations in total, 30 deg aparts. 'nat_movie' and 'cage_movie' correspond to movie trials, where a 60s of natural or animal home cage movie is played. - for ESCN00-03, start of each of these trials is marked by a rise of 500 ms TTL + for ESCN00 to 03, start of each of these trials is marked by a rise of 500 ms TTL pulse. Within each grating trial, each orientation grating shows for 2s, their beginnings and ends are both marked by a rise of 500 ms TTL pulse. - *dec 16th 2022: in the near future, there will be 'nat_image' trial, where + *dec 16th 2022: in the near future, there will be 'nat_image' trial, where... To align trials to more than one action type they can be bitwise OR'd i.e. `miss_left | miss_right` will match all miss trials. 
""" - # gratings 0.04 spatial freq - g_0.04_0 = 1 << 0 - g_0.04_30 = 1 << 1 - g_0.04_60 = 1 << 2 - g_0.04_90 = 1 << 3 - g_0.04_120 = 1 << 4 - g_0.04_150 = 1 << 5 - g_0.04_180 = 1 << 6 - g_0.04_210 = 1 << 7 - g_0.04_240 = 1 << 8 - g_0.04_270 = 1 << 9 - g_0.04_300 = 1 << 10 - g_0.04_330 = 1 << 11 - - # gratings 0.16 spatial freq - g_0.16_0 = 1 << 12 - g_0.16_30 = 1 << 13 - g_0.16_60 = 1 << 14 - g_0.16_90 = 1 << 15 - g_0.16_120 = 1 << 16 - g_0.16_150 = 1 << 17 - g_0.16_180 = 1 << 18 - g_0.16_210 = 1 << 19 - g_0.16_240 = 1 << 20 - g_0.16_270 = 1 << 21 - g_0.16_300 = 1 << 22 - g_0.16_330 = 1 << 23 + # gratings 0.04 & 0.16 spatial frequency + sf_004 = 1 << 0 + sf_016 = 1 << 1 + + # directions + left = 1 << 2 + right = 1 << 3 + + # orientations + g_0 = 1 << 4 + g_30 = 1 << 5 + g_60 = 1 << 6 + g_90 = 1 << 7 + g_120 = 1 << 8 + g_150 = 1 << 9 + + # iti marks + black = 1 << 10 + gray = 1 << 11 # movies - nat_movie = 1 << 24 - cage_movie = 1 << 25 + nat_movie = 1 << 12 + cage_movie = 1 << 13 - # iti marks - black = 1 << 26 - gray = 1 << 27 + """ + # spatial freq 0.04 directions 0:30:150 + g_004_0 = sf_004 & g_0 & left + g_004_30 = sf_004 & g_30 & left + g_004_60 = sf_004 & g_60 & left + g_004_90 = sf_004 & g_90 & left + g_004_120 = sf_004 & g_120 & left + g_004_150 = sf_004 & g_150 & left + + # spatial freq 0.04 directions 180:30:330 + g_004_180 = sf_004 & g_0 & right + g_004_210 = sf_004 & g_30 & right + g_004_240 = sf_004 & g_60 & right + g_004_270 = sf_004 & g_90 & right + g_004_300 = sf_004 & g_120 & right + g_004_330 = sf_004 & g_150 & right + + # spatial freq 0.16 directions 0:30:150 + g_016_0 = sf_016 & g_0 & left + g_016_30 = sf_016 & g_30 & left + g_016_60 = sf_016 & g_60 & left + g_016_90 = sf_016 & g_90 & left + g_016_120 = sf_016 & g_120 & left + g_016_150 = sf_016 & g_150 & left + + # spatial freq 0.16 directions 180:30:330 + g_016_180 = sf_016 & g_0 & right + g_016_210 = sf_016 & g_30 & right + g_016_240 = sf_016 & g_60 & right + g_016_270 = sf_016 & g_90 & right + g_016_300 = sf_016 & g_120 & right + g_016_330 = sf_016 & g_150 & right + + # spatial freq 0.04 + g_004 = sf_004 & (g_0 | )& (left | right) + g_004 = g_004_0 | g_004_30 | g_004_60 | g_004_90 | + + # spatial freq 0.16 + """ #TODO: natural images - class Events: - led_on = 1 << 0 - led_off = 1 << 1 - - # Timepoints determined from motion tracking - reach_onset = 1 << 2 - slit_in = 1 << 3 - grasp = 1 << 4 - slit_out = 1 << 5 - subsequent_slit_in = 1 << 6 # The SECOND full reach on a clean correct trial only - subsequent_grasp = 1 << 7 - subsequent_slit_out = 1 << 8 - + start = 1 << 0 + end = 1 << 1 # These are used to convert the trial data into Actions and Events _side_map = { @@ -102,7 +121,8 @@ class Events: -class Reach(Behaviour): +class PassiveViewing(Behaviour): + #TODO: continue here, see how to map behaviour metadata def _preprocess_behaviour(self, rec_num, behavioural_data): # Correction for sessions where sync channel interfered with LED channel if behavioural_data["/'ReachLEDs'/'0'"].min() < -2: From 27e5f6fdfcf37845843b96aeb9cf8077563a7e3c Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 23 Dec 2022 16:42:46 +0000 Subject: [PATCH 024/658] add func to get waveform metrics for unit type clustering --- pixels/behaviours/base.py | 101 +++++++++++++++++++++++++++++++++++++- pixels/experiment.py | 20 ++++++++ 2 files changed, 119 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 10e2490..973a7cb 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py 
@@ -699,7 +699,9 @@ def sort_spikes(self, CatGT_app=None, old=False): folder=self.interim / 'cache', #load_if_exists=True, # load extracted if available load_if_exists=False, # re-calculate everytime - max_spikes_per_unit=1000, # None will extract all waveforms + max_spikes_per_unit=500, # None will extract all waveforms + ms_before=2.0, # time before trough + ms_after=2.0, # time after trough #overwrite=False, overwrite=True, **job_kwargs, @@ -1930,7 +1932,7 @@ def get_spike_waveforms(self, units=None, method='phy'): rec_forms = {} for u, unit in enumerate(units): - print(100 * u / len(units), "% complete") + print(f"{round(100 * u / len(units), 2)}% complete") # get the waveforms from only the best channel spike_ids = model.get_cluster_spikes(unit) best_chan = model.get_cluster_channels(unit)[0] @@ -2039,6 +2041,101 @@ def get_spike_waveforms(self, units=None, method='phy'): raise PixelsError(f"{self.name}: waveform extraction method {method} is\ not implemented!") + + @_cacheable + def get_waveform_metrics(self, units=None, window=20): + """ + This func is a work-around of spikeinterface's equivalent: + https://github.com/SpikeInterface/spikeinterface/blob/master/spikeinterface/postprocessing/template_metrics.py. + dec 23rd 2022: motivation to write this function is that spikeinterface + 0.96.1 cannot load sorting object from `export_to_phy` output folder, i.e., i + cannot get updated clusters/units and their waveforms, which is a huge + problem for subsequent analyses e.g. unit type clustering. + + To learn more about waveform metrics, see + https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/mean_waveforms + and https://journals.physiology.org/doi/full/10.1152/jn.00680.2018. + + """ + if units: + # Always defer to getting waveform metrics for all good units, so we only + # ever have to calculate metrics for each once. 
+ wave_metrics = self.get_waveform_metrics() + return wave_metrics.loc[wave_metrics.unit.isin(units)] + + columns = ["unit", "trough_to_peak", "trough_peak_ratio", "half_width", + "repolarisation_slope", "recovery_slope"] + print(f"> Calculating waveform metrics {columns[1:]}...\n") + + waveforms = self.get_spike_waveforms() + units = waveforms.columns.get_level_values('unit').unique() + + output = {} + for i, unit in enumerate(units): + metrics = [] + mean_waveform = waveforms[unit].mean(axis=1) + + # time between trough to peak, in ms + trough_idx = np.argmin(mean_waveform) + peak_idx = trough_idx + np.argmax(mean_waveform.iloc[trough_idx:]) + if peak_idx == 0: + raise PixelsError(f"> Cannot find peak in mean waveform.") + if trough_idx == 0: + raise PixelsError(f"> Cannot find trough in mean waveform.") + trough_to_peak = mean_waveform.index[peak_idx] - mean_waveform.index[trough_idx] + metrics.append(trough_to_peak) + + # trough to peak ratio + trough_peak_ratio = mean_waveform.iloc[peak_idx] / mean_waveform.iloc[trough_idx] + metrics.append(trough_peak_ratio) + + # spike half width, in ms + half_amp = mean_waveform.iloc[trough_idx] / 2 + idx_pre_half = np.where(mean_waveform.iloc[:trough_idx] < half_amp) + idx_post_half = np.where(mean_waveform.iloc[trough_idx:] < half_amp) + # last occurence of mean waveform amp lower than half amp, before trough + time_pre_half = mean_waveform.iloc[idx_pre_half[0] - 1].index[0] + # first occurence of mean waveform amp lower than half amp, after trough + time_post_half = mean_waveform.iloc[idx_post_half[0] + 1 + + trough_idx].index[-1] + half_width = time_post_half - time_pre_half + metrics.append(half_width) + + # repolarisation slope + returns = np.where(mean_waveform.iloc[trough_idx:] >= 0) + trough_idx + if len(returns[0]) == 0: + raise PixelsError(f"> The mean waveformrns never returned to baseline?") + return_idx = returns[0][0] + if return_idx - trough_idx < 3: + raise PixelsError(f"> The mean waveform returns to baseline too quickly,\ + \ndoes not make sense...") + repo_period = mean_waveform.iloc[trough_idx:return_idx] + repo_slope = scipy.stats.linregress( + x=repo_period.index.values, # time in ms + y=repo_period.values, # amp + ).slope + metrics.append(repo_slope) + + # recovery slope during user-defined recovery period + recovery_end_idx = peak_idx + window + recovery_end_idx = np.min([recovery_end_idx, mean_waveform.shape[0]]) + reco_period = mean_waveform.iloc[peak_idx:recovery_end_idx] + reco_slope = scipy.stats.linregress( + x=reco_period.index.values, # time in ms + y=reco_period.values, # amp + ).slope + metrics.append(reco_slope) + + # save metrics in output dictionary, key is unit id + output[unit] = metrics + + # save all template metrics as dataframe + df = pd.DataFrame(output).T.reset_index() + df.columns = columns + + return df + + @_cacheable def get_aligned_spike_rate_CI( self, label, event, diff --git a/pixels/experiment.py b/pixels/experiment.py index 70ed042..e1d1a05 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -299,6 +299,26 @@ def get_spike_waveforms(self, units=None): ) return df + def get_waveform_metrics(self, units=None): + """ + Get waveform metrics of mean waveform for units matching the specified criteria. 
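To make two of the metrics named above concrete, here is a small synthetic check of trough-to-peak time and half-width; the waveform shape and the 0.1 ms sampling grid are invented for illustration and this is not the session implementation:

    import numpy as np

    time = np.arange(-2.0, 3.0, 0.1)                  # ms relative to the trough
    wave = -np.exp(-time ** 2 / 0.05) + 0.3 * np.exp(-(time - 0.6) ** 2 / 0.2)

    trough = np.argmin(wave)
    peak = trough + np.argmax(wave[trough:])
    trough_to_peak = time[peak] - time[trough]        # ~0.6 ms for this waveform

    half_amp = wave[trough] / 2
    below = np.where(wave < half_amp)[0]              # samples deeper than half the trough
    half_width = time[below[-1]] - time[below[0]]     # trough width at half amplitude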
+ """ + waveform_metrics = {} + + for i, session in enumerate(self.sessions): + if units: + if units[i]: + waveform_metrics[i] = session.get_waveform_metrics(units=units[i]) + else: + waveform_metrics[i] = session.get_waveform_metrics() + + df = pd.concat( + waveform_metrics.values(), axis=1, copy=False, + keys=waveform_metrics.keys(), + names=["session"] + ) + return df + def get_aligned_spike_rate_CI(self, *args, units=None, **kwargs): """ Get the confidence intervals of the mean firing rates within a window aligned to From 575ece84236a1937131b51eee62e4753a242cc6c Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 27 Dec 2022 12:20:44 +0000 Subject: [PATCH 025/658] define waveform extraction time period; include units with depths equal to min_depth; resolve waveform metrics calculation bugs; makes unit id int64 --- pixels/behaviours/base.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 973a7cb..ef255d0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -701,7 +701,7 @@ def sort_spikes(self, CatGT_app=None, old=False): load_if_exists=False, # re-calculate everytime max_spikes_per_unit=500, # None will extract all waveforms ms_before=2.0, # time before trough - ms_after=2.0, # time after trough + ms_after=3.0, # time after trough #overwrite=False, overwrite=True, **job_kwargs, @@ -1558,7 +1558,7 @@ def select_units( # and that are within the specified depth range if min_depth is not None: - if probe_depth - unit_info['depth'] < min_depth: + if probe_depth - unit_info['depth'] <= min_depth: continue if max_depth is not None: if probe_depth - unit_info['depth'] > max_depth: @@ -2079,9 +2079,9 @@ def get_waveform_metrics(self, units=None, window=20): trough_idx = np.argmin(mean_waveform) peak_idx = trough_idx + np.argmax(mean_waveform.iloc[trough_idx:]) if peak_idx == 0: - raise PixelsError(f"> Cannot find peak in mean waveform.") + raise PixelsError(f"> Cannot find peak in mean waveform.\n") if trough_idx == 0: - raise PixelsError(f"> Cannot find trough in mean waveform.") + raise PixelsError(f"> Cannot find trough in mean waveform.\n") trough_to_peak = mean_waveform.index[peak_idx] - mean_waveform.index[trough_idx] metrics.append(trough_to_peak) @@ -2094,7 +2094,11 @@ def get_waveform_metrics(self, units=None, window=20): idx_pre_half = np.where(mean_waveform.iloc[:trough_idx] < half_amp) idx_post_half = np.where(mean_waveform.iloc[trough_idx:] < half_amp) # last occurence of mean waveform amp lower than half amp, before trough - time_pre_half = mean_waveform.iloc[idx_pre_half[0] - 1].index[0] + if len(idx_pre_half[0]) == 0: + idx_pre_half = trough_idx - 1 + time_pre_half = mean_waveform.index[idx_pre_half] + else: + time_pre_half = mean_waveform.iloc[idx_pre_half[0] - 1].index[0] # first occurence of mean waveform amp lower than half amp, after trough time_post_half = mean_waveform.iloc[idx_post_half[0] + 1 + trough_idx].index[-1] @@ -2104,11 +2108,13 @@ def get_waveform_metrics(self, units=None, window=20): # repolarisation slope returns = np.where(mean_waveform.iloc[trough_idx:] >= 0) + trough_idx if len(returns[0]) == 0: - raise PixelsError(f"> The mean waveformrns never returned to baseline?") - return_idx = returns[0][0] - if return_idx - trough_idx < 3: - raise PixelsError(f"> The mean waveform returns to baseline too quickly,\ - \ndoes not make sense...") + print(f"> The mean waveformrns never returned to baseline?\n") + return_idx = 
mean_waveform.shape[0] - 1 + else: + return_idx = returns[0][0] + if return_idx - trough_idx < 2: + raise PixelsError(f"> The mean waveform returns to baseline too quickly,\ + \ndoes not make sense...\n") repo_period = mean_waveform.iloc[trough_idx:return_idx] repo_slope = scipy.stats.linregress( x=repo_period.index.values, # time in ms @@ -2132,6 +2138,8 @@ def get_waveform_metrics(self, units=None, window=20): # save all template metrics as dataframe df = pd.DataFrame(output).T.reset_index() df.columns = columns + dtype = {"unit": int} + df = df.astype(dtype) return df From af53ed693d6d61d7924a835bb97b25b9cd661b9b Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 20 Jan 2023 00:49:19 +0000 Subject: [PATCH 026/658] add func to make load recording easier; try normalise average waveform before compute metrics --- pixels/behaviours/base.py | 203 +++++++++++++++++++++++++------------- 1 file changed, 132 insertions(+), 71 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index ef255d0..1ec6e32 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -178,6 +178,7 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self._lag = None self._use_cache = True self._cluster_info = None + self._good_unit_info = None self.drop_data() self.spike_meta = [ @@ -566,6 +567,59 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: subprocess.run( ['./run_catgt.sh', session_args]) + def load_recording(self): + try: + recording = si.load_extractor(self.interim / 'cache/recording.json') + concat_rec = recording + output = os.path.dirname(self.ks_output) + return recording, output + + except: + for _, files in enumerate(self.files): + try: + print("\n> Getting catgt-ed recording...") + self.CatGT_dir = Path(self.CatGT_dir[0]) + data_dir = self.CatGT_dir + data_file = data_dir / files['catGT_ap_data'] + metadata = data_dir / files['catGT_ap_meta'] + except: + print(f"\n> Getting the orignial recording...") + data_file = self.find_file(files['spike_data']) + metadata = self.find_file(files['spike_meta']) + + assert 0 + stream_id = data_file.as_posix()[-12:-4] + if stream_id not in streams: + streams[stream_id] = metadata + + for stream_num, stream in enumerate(streams.items()): + stream_id, metadata = stream + # find spike sorting output folder + if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: + output = self.processed / f'sorted_stream_cat_{stream_num}' + else: + output = self.processed / f'sorted_stream_{stream_num}' + + try: + recording = se.SpikeGLXRecordingExtractor(self.CatGT_dir, stream_id=stream_id) + except ValueError as e: + raise PixelsError( + f"Did the raw data get fully copied to interim? Full error: {e}\n" + ) + + # this recording is filtered + recording.annotate(is_filtered=True) + + # concatenate recording segments + concat_rec = si.concatenate_recordings([recording]) + probe = pi.read_spikeglx(metadata.as_posix()) + concat_rec = concat_rec.set_probe(probe) + # annotate spike data is filtered + concat_rec.annotate(is_filtered=True) + + return concat_rec, output + + def sort_spikes(self, CatGT_app=None, old=False): """ Run kilosort spike sorting on raw spike data. 
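If the TODO above is pursued, the recording returned by load_recording() could plausibly be handed straight to spikeinterface's run_sorter. A rough sketch; the sorter name, parameters and the `ss` alias are assumptions, not the project's configuration:

    import spikeinterface.sorters as ss

    def sort_with_loaded_recording(session, sorter="kilosort3"):
        # `session` is a Behaviour instance exposing the load_recording() added above
        recording, output = session.load_recording()
        # requires the chosen sorter to be installed and available
        return ss.run_sorter(sorter, recording, output_folder=output, verbose=True)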
@@ -578,6 +632,10 @@ def sort_spikes(self, CatGT_app=None, old=False): progress_bar=True, ) + concat_rec, output = self.load_recording() + + assert 0 + #TODO: see if ks can run normally now using load_recording() for _, files in enumerate(self.files): if not CatGT_app == None: self.run_catgt(CatGT_app=CatGT_app) @@ -1105,6 +1163,7 @@ def process_motion_index(self, video_match): # Get MIs avi = self.interim / video.with_suffix('.avi') + #TODO: how to load recording rec_rois, roi_file = ses_rois[(rec_num, v)] rec_mi = signal.motion_index(avi, rec_rois) @@ -1879,6 +1938,19 @@ def get_cluster_info(self): self._cluster_info = info return self._cluster_info + def get_good_units_info(self): + if self._good_unit_info is None: + info_file = self.interim / 'good_units_info.tsv' + print(f"> got good unit info at {info_file}\n") + + try: + info = pd.read_csv(info_file, sep='\t') + except FileNotFoundError: + msg = ": Can't load cluster info. Did you export good unit info for this session yet?" + raise PixelsError(self.name + msg) + self._good_unit_info = info + return self._good_unit_info + @_cacheable def get_spike_widths(self, units=None): if units: @@ -1912,6 +1984,12 @@ def get_spike_widths(self, units=None): @_cacheable def get_spike_waveforms(self, units=None, method='phy'): + """ + Extracts waveforms of spikes. + method: str, name of selected method. + 'phy' (default) + 'spikeinterface' + """ if method == 'phy': from phylib.io.model import load_model from phylib.utils.color import selected_cluster_color @@ -1955,87 +2033,64 @@ def get_spike_waveforms(self, units=None, method='phy'): #TODO: implement spikeinterface waveform extraction elif method == 'spikeinterface': - streams = {} # set chunks job_kwargs = dict( n_jobs=10, # -1: num of job equals num of cores chunk_duration="1s", progress_bar=True, ) + recording, _ = self.load_recording() + try: + sorting = se.read_kilosort(self.ks_output) + except ValueError as e: + raise PixelsError( + f"Can't load sorting object. Did you delete cluster_info.csv? 
Full error: {e}\n" + ) - # load recording and sorting object - for _, files in enumerate(self.files): - if len(self.CatGT_dir) == 0: - print(f"> Spike data not found for {files['catGT_ap_data']},\ - \nuse the orignial recording data.") - data_file = self.find_file(files['spike_data']) - metadata = self.find_file(files['spike_meta']) + # check last modified time of cache, and create time of ks_output + try: + template_cache_mod_time = os.path.getmtime(self.interim / + 'cache/templates_average.npy') + ks_mod_time = os.path.getmtime(self.ks_output / 'cluster_info.tsv') + assert template_cache_mod_time < ks_mod_time + check = True # re-extract waveforms + print("> Re-extracting waveforms since kilosort output is newer.") + except: + if 'template_cache_mod_time' in locals(): + print("> Loading existing waveforms.") + check = False # load existing waveforms else: - print("> Use catgt-ed recording") - self.CatGT_dir = Path(self.CatGT_dir[0]) - data_file = self.CatGT_dir / files['catGT_ap_data'] - metadata = self.CatGT_dir / files['catGT_ap_meta'] + print("> Extracting waveforms since they are not extracted.") + check = True # re-extract waveforms - stream_id = data_file.as_posix()[-12:-4] - if stream_id not in streams: - streams[stream_id] = metadata - - for stream_num, stream in enumerate(streams.items()): - stream_id, metadata = stream - try: - recording = se.SpikeGLXRecordingExtractor(self.CatGT_dir, stream_id=stream_id) - # this recording is filtered - recording.annotate(is_filtered=True) - except ValueError as e: - raise PixelsError( - f"Did the raw data get fully copied to interim? Full error: {e}" - ) - try: - # load sorting object - sorting = si.load_extractor(self.processed) - except ValueError as e: - raise PixelsError( - f"Have you run spike sorting yet? Full error: {e}" - ) - - try: - waveforms = si.WaveformExtractor.load_from_folder( - folder=self.interim / 'cache', - sorting=sorting, - ) - except: - print("> Waveforms not extracted, extracting now.") - - #TODO - if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: - output = self.processed / f'sorted_stream_cat_{stream_num}' - else: - output = self.processed / f'sorted_stream_{stream_num}' + """ + # for testing: get first 5 mins of the recording + fs = concat_rec.get_sampling_frequency() + test = concat_rec.frame_slice( + start_frame=0*fs, + end_frame=300*fs, + ) + test.annotate(is_filtered=True) + # check all annotations + test.get_annotation('is_filtered') + print(test) + """ - # for testing: get first 5 mins of the recording - fs = concat_rec.get_sampling_frequency() - test = concat_rec.frame_slice( - start_frame=0*fs, - end_frame=300*fs, - ) - test.annotate(is_filtered=True) - # check all annotations - test.get_annotation('is_filtered') - print(test) + # extract waveforms + waveforms = si.extract_waveforms( + recording=recording, + sorting=sorting, + folder=self.interim / 'cache', + load_if_exists=not(check), # maybe re-extracted + max_spikes_per_unit=500, # None will extract all waveforms + ms_before=2.0, # time before trough + ms_after=3.0, # time after trough + overwrite=check, # overwrite depends on check + **job_kwargs, + ) + #TODO: use cache to export the results? 
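Once extracted, the cached WaveformExtractor can be read back and queried per unit. A brief sketch, assuming the common `spikeinterface.full as si` alias for the `si` used in this module and an illustrative cache path:

    from pathlib import Path
    import spikeinterface.full as si   # assumed alias

    cache = Path("interim/some_session/cache")        # illustrative path
    we = si.WaveformExtractor.load_from_folder(cache)
    unit_id = we.sorting.unit_ids[0]
    spikes = we.get_waveforms(unit_id)                # (n_spikes, n_samples, n_channels)
    template = we.get_template(unit_id)               # waveform averaged over those spikes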
- # extract waveforms - waveforms = si.extract_waveforms( - #recording=concat_rec, - recording=test, # for testing - sorting=ks3_no_empt, # sorting=ks3_output after remove dups - folder=self.interim / 'cache', - #load_if_exists=True, # load extracted if available - load_if_exists=False, # re-calculate everytime - max_spikes_per_unit=None, # extract all waveforms - overwrite=False, - **job_kwargs, - ) - assert 0 + return waveforms else: raise PixelsError(f"{self.name}: waveform extraction method {method} is\ @@ -2068,12 +2123,18 @@ def get_waveform_metrics(self, units=None, window=20): print(f"> Calculating waveform metrics {columns[1:]}...\n") waveforms = self.get_spike_waveforms() + # remove nan values + waveforms.dropna() units = waveforms.columns.get_level_values('unit').unique() output = {} for i, unit in enumerate(units): metrics = [] - mean_waveform = waveforms[unit].mean(axis=1) + mean_waveform = waveforms[unit].median(axis=1) + # normalise mean waveform to remove variance caused by distance! + mean_waveform = mean_waveform / mean_waveform.abs().max() + #TODO: test! also can try clustering on normalised meann waveform + assert 0 # time between trough to peak, in ms trough_idx = np.argmin(mean_waveform) From 53ebbeeb10ae8fff3b2b248f172ed967d3811543 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 20 Jan 2023 00:51:52 +0000 Subject: [PATCH 027/658] get info table for selected sessions; allows pooling waveform metrics of units for each animal across sessions --- pixels/experiment.py | 68 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 11 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index e1d1a05..2585749 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -259,6 +259,36 @@ def get_cluster_info(self): """ return [s.get_cluster_info() for s in self.sessions] + def get_good_units_info(self): + """ + Get some basic high-level information for each good unit. This is mostly just the + information seen in the table in phy, plus their region. + """ + info = {} + + for m, mouse in enumerate(self.mouse_ids): + info[mouse] = {} + for i, session in enumerate(self.sessions): + if mouse in session.name: + info[mouse][i] = session.get_good_units_info() + + long_df = pd.concat( + info[mouse], + axis=0, + names=["session", "unit_idx"], + ) + info[mouse] = long_df + + info_per_mouse = info + + info_pooled = pd.concat( + info, axis=0, copy=False, + keys=info.keys(), + names=["mouse", "session", "unit_idx"], + ) + + return info_per_mouse, info_pooled + def get_spike_widths(self, units=None): """ Get the widths of spikes for units matching the specified criteria. @@ -301,23 +331,39 @@ def get_spike_waveforms(self, units=None): def get_waveform_metrics(self, units=None): """ - Get waveform metrics of mean waveform for units matching the specified criteria. + Get waveform metrics of mean waveform for units matching the specified + criteria; separated by mouse. 
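To make the resulting index structure concrete, a toy version of the per-mouse pooling with invented numbers:

    import pandas as pd

    per_session = {
        0: pd.DataFrame({"unit": [1, 2], "half_width": [0.25, 0.40]}),
        1: pd.DataFrame({"unit": [3], "half_width": [0.31]}),
    }
    per_mouse = pd.concat(per_session, names=["session", "unit_idx"])
    pooled = pd.concat({"MOUSE1": per_mouse}, names=["mouse", "session", "unit_idx"])
    # pooled.index is a 3-level MultiIndex: (mouse, session, unit_idx)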
""" waveform_metrics = {} - for i, session in enumerate(self.sessions): - if units: - if units[i]: - waveform_metrics[i] = session.get_waveform_metrics(units=units[i]) - else: - waveform_metrics[i] = session.get_waveform_metrics() + for m, mouse in enumerate(self.mouse_ids): + waveform_metrics[mouse] = {} + for i, session in enumerate(self.sessions): + if mouse in session.name: + if units: + if units[i]: + waveform_metrics[mouse][i] = session.get_waveform_metrics(units=units[i]) + else: + waveform_metrics[mouse][i] = session.get_waveform_metrics() + + long_df = pd.concat( + waveform_metrics[mouse], + axis=0, + names=["session", "unit_idx"], + ) + # drop nan rows + long_df.dropna(inplace=True) + waveform_metrics[mouse] = long_df - df = pd.concat( - waveform_metrics.values(), axis=1, copy=False, + waveform_metrics_per_mouse = waveform_metrics + + waveform_metrics_pooled = pd.concat( + waveform_metrics, axis=0, copy=False, keys=waveform_metrics.keys(), - names=["session"] + names=["mouse", "session", "unit_idx"], ) - return df + + return waveform_metrics_per_mouse, waveform_metrics_pooled def get_aligned_spike_rate_CI(self, *args, units=None, **kwargs): """ From 4a88ef9e5343f12e474641313f1abaa8bf4887b6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 25 Jan 2023 18:35:04 +0000 Subject: [PATCH 028/658] make sure session id resets for each mouse --- pixels/experiment.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 2585749..710feaf 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -268,9 +268,13 @@ def get_good_units_info(self): for m, mouse in enumerate(self.mouse_ids): info[mouse] = {} - for i, session in enumerate(self.sessions): + mouse_sessions = [] + for session in self.sessions: if mouse in session.name: - info[mouse][i] = session.get_good_units_info() + mouse_sessions.append(session) + + for i, session in enumerate(mouse_sessions): + info[mouse][i] = session.get_good_units_info() long_df = pd.concat( info[mouse], @@ -337,14 +341,18 @@ def get_waveform_metrics(self, units=None): waveform_metrics = {} for m, mouse in enumerate(self.mouse_ids): + mouse_sessions = [] waveform_metrics[mouse] = {} - for i, session in enumerate(self.sessions): + for session in self.sessions: if mouse in session.name: - if units: - if units[i]: - waveform_metrics[mouse][i] = session.get_waveform_metrics(units=units[i]) - else: - waveform_metrics[mouse][i] = session.get_waveform_metrics() + mouse_sessions.append(session) + + for i, session in enumerate(mouse_sessions): + if units: + if units[i]: + waveform_metrics[mouse][i] = session.get_waveform_metrics(units=units[i]) + else: + waveform_metrics[mouse][i] = session.get_waveform_metrics() long_df = pd.concat( waveform_metrics[mouse], From 3d76332b2baab9cbf820de59633e7654e51269a1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 25 Jan 2023 18:36:26 +0000 Subject: [PATCH 029/658] normalise average waveform for a more 'accurate' waveform metrics --- pixels/behaviours/base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1ec6e32..0cf21b3 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2098,7 +2098,7 @@ def get_spike_waveforms(self, units=None, method='phy'): @_cacheable - def get_waveform_metrics(self, units=None, window=20): + def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): """ This func is a work-around 
of spikeinterface's equivalent: https://github.com/SpikeInterface/spikeinterface/blob/master/spikeinterface/postprocessing/template_metrics.py. @@ -2130,11 +2130,12 @@ def get_waveform_metrics(self, units=None, window=20): output = {} for i, unit in enumerate(units): metrics = [] - mean_waveform = waveforms[unit].median(axis=1) + #mean_waveform = waveforms[unit].mean(axis=1) + median_waveform = waveforms[unit].median(axis=1) # normalise mean waveform to remove variance caused by distance! - mean_waveform = mean_waveform / mean_waveform.abs().max() + norm_waveform = median_waveform / median_waveform.abs().max() #TODO: test! also can try clustering on normalised meann waveform - assert 0 + mean_waveform = norm_waveform # time between trough to peak, in ms trough_idx = np.argmin(mean_waveform) @@ -2201,6 +2202,8 @@ def get_waveform_metrics(self, units=None, window=20): df.columns = columns dtype = {"unit": int} df = df.astype(dtype) + # see which cols have nan + df.isnull().sum() return df From 383f162172ff7b8c90fc15e4006afd2146c521a2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 25 Jan 2023 18:37:41 +0000 Subject: [PATCH 030/658] default of resample --- pixels/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/signal.py b/pixels/signal.py index 6451c92..9e2c36f 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -33,7 +33,7 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): poly : bool, choose the resample function. If True, use scipy.signal.resample_poly; if False, use scipy.signal.resample. - Default is False. + Default is True. lfp downsampling only works if using scipy.signal.resample. Returns From 91aab710fd67493a2c30c46167b36f72abd05f68 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 27 Jan 2023 19:55:33 +0000 Subject: [PATCH 031/658] add annotation --- pixels/behaviours/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 0cf21b3..908cbba 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1928,7 +1928,7 @@ def align_clips(self, label, event, video_match, duration=1): def get_cluster_info(self): if self._cluster_info is None: info_file = self.ks_output / 'cluster_info.tsv' - print(f"> got cluster info at {info_file}\n") + #print(f"> got cluster info at {info_file}\n") try: info = pd.read_csv(info_file, sep='\t') @@ -1940,8 +1940,9 @@ def get_cluster_info(self): def get_good_units_info(self): if self._good_unit_info is None: + #az: good_units_info.tsv saved while running depth_profile.py info_file = self.interim / 'good_units_info.tsv' - print(f"> got good unit info at {info_file}\n") + #print(f"> got good unit info at {info_file}\n") try: info = pd.read_csv(info_file, sep='\t') From 3f7f7ee457c93c071c935ebee069d50a2f48ebaf Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 8 Feb 2023 18:56:23 +0000 Subject: [PATCH 032/658] fix typos; try adding sync edges file made by catgt; start working on sync npx streams --- pixels/behaviours/base.py | 32 ++++++++++++++++++++++++++++++++ pixels/ioutils.py | 12 ++++++++++-- pixels/signal.py | 2 +- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 908cbba..7342e04 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -363,6 +363,35 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): with (self.processed / 'lag.json').open('w') as fd: json.dump(lag_json, fd) + def 
sync_streams(self): + """ + Neuropixels data streams acquired simultaneously are not synchronised, unless + they are plugged into the same headstage, which is only the case for + Neuropixels 2.0 probes. + Dual-recording data acquired by Neuropixels 1.0 probes needs to be + synchronised. + Specifically, spike times from imecX stream need to be re-mapped to imec0 + time scale. + + For more info, see + https://open-ephys.github.io/gui-docs/Tutorials/Data-Synchronization.html and + https://github.com/billkarsh/TPrime. + """ + edges = [] + for rec_num, recording in enumerate(self.files): + # get file names and stuff + spike_data = self.find_file(recording['spike_data']) + spike_meta = self.find_file(files['spike_meta']) + stream_id = spike_data.as_posix()[-12:-4] + self.gate_idx = spike_data.as_posix()[-18:-16] + self.trigger_idx = spike_data.as_posix()[-15:-13] + + # find extracted rising sync edges, rn from CatGT + CatGT_output = self.interim / ('catgt_' + self.name + self.gate_idx) + edges + if stream_id not in streams: + streams[stream_id] = metadata + def process_behaviour(self): """ Process behavioural data from raw tdms and align to neuropixels data. @@ -568,6 +597,9 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: def load_recording(self): + """ + Write a function to load recording. + """ try: recording = si.load_extractor(self.interim / 'cache/recording.json') concat_rec = recording diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 5b43884..36106b5 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -117,8 +117,16 @@ def get_data_files(data_dir, session_name): recording['depth_info'] = recording['lfp_data'].with_name( f'depth_info_{num}.json' ) - recording['catGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") - recording['catGT_ap_meta'] = str(recording['spike_meta']).replace("t0", "tcat") + recording['CatGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") + recording['CatGT_ap_meta'] = str(recording['spike_meta']).replace("t0", "tcat") + recording['CatGT_ap_sync'] = str(recording['spike_meta']).replace("t0", + "tcat").replace("meta", "xd_384") + recording['CatGT_ap_sync'] = sorted(glob.glob( + rf'{data_dir}/catgt_{session_name}_g[0-9]' + '/*xd*.txt', recursive=True) + ) + #TODO: does this way of finding sync edges file works? + assert 0 + files.append(recording) return files diff --git a/pixels/signal.py b/pixels/signal.py index 9e2c36f..61254de 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -134,7 +134,7 @@ def _binarise_real(data): def find_sync_lag(array1, array2, plot=False): """ Find the lag between two arrays where they have the greatest number of the same - values. This functions assumes that the lag is less than 120,000 points. + values. This functions assumes that the lag is less than 300,000 points. 
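As an aside, a naive standalone version of this matching idea (not the pixels implementation) makes the assumption explicit: the lag is the shift that maximises the number of equal samples, searched up to the stated limit:

    import numpy as np

    def naive_sync_lag(a, b, max_lag=300_000):
        # shift of `b` within `a` giving the most equal samples
        best_lag, best_match = 0, -1
        for lag in range(min(max_lag, len(a))):
            n = min(len(a) - lag, len(b))
            match = int(np.sum(a[lag:lag + n] == b[:n]))
            if match > best_match:
                best_lag, best_match = lag, match
        return best_lag

    # e.g. two binarised sync channels where the second starts 3 samples late
    a = np.array([0, 0, 0, 1, 1, 0, 1, 0, 1, 1])
    assert naive_sync_lag(a, a[3:]) == 3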
Parameters ---------- From 3ff8c62cd0e98c5268c36cd51de43132893d8690 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Feb 2023 15:44:21 +0000 Subject: [PATCH 033/658] optimise ks_output definition; incorporate multiple streams of neuropixels data --- pixels/behaviours/base.py | 156 +++++++++++++++++++++++--------------- 1 file changed, 93 insertions(+), 63 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7342e04..d456b43 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -139,33 +139,32 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self.raw = self.data_dir / 'raw' / self.name self.processed = self.data_dir / 'processed' / self.name - ks_output = glob.glob( + self.files = ioutils.get_data_files(self.raw, name) + + self.ks_outputs = sorted(glob.glob( str(self.processed) +'/' + f'sorted_stream_cat_[0-9]' - ) - if not len(ks_output) == 0: - ks_output = Path(ks_output[0]) - if not ((ks_output / 'phy_ks3').exists() and - len(os.listdir(ks_output / 'phy_ks3'))>17): - #if not (ks_output.exists() and - #len(os.listdir(ks_output / ks_output))>17): - self.ks_output = ks_output + )) + for i, output in enumerate(self.ks_outputs): + if not len(self.ks_outputs) == 0: + output = Path(output) + if not ((output / 'phy_ks3').exists() and + len(os.listdir(output / 'phy_ks3'))>17): + self.ks_outputs[i] = output + else: + self.ks_outputs[i] = output / 'phy_ks3' else: - self.ks_output = ks_output / 'phy_ks3' - else: - self.ks_output = Path(glob.glob( - str(self.processed) +'/' + f'sorted_stream_[0-9]' - )[0]) - - self.files = ioutils.get_data_files(self.raw, name) + output = sorted(glob.glob( + str(self.processed) +'/' + f'sorted_stream_[0-9]' + )) if interim_dir is None: self.interim = self.data_dir / 'interim' / self.name else: self.interim = Path(interim_dir) / self.name - self.CatGT_dir = glob.glob( + self.CatGT_dir = sorted(glob.glob( str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' - ) + )) self.interim.mkdir(parents=True, exist_ok=True) self.processed.mkdir(parents=True, exist_ok=True) @@ -363,7 +362,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): with (self.processed / 'lag.json').open('w') as fd: json.dump(lag_json, fd) - def sync_streams(self): + def sync_streams(self, SYNC_BIN): """ Neuropixels data streams acquired simultaneously are not synchronised, unless they are plugged into the same headstage, which is only the case for @@ -376,21 +375,52 @@ def sync_streams(self): For more info, see https://open-ephys.github.io/gui-docs/Tutorials/Data-Synchronization.html and https://github.com/billkarsh/TPrime. + + params + === + SYNC_BIN: int, number of rising sync edges for calculating the scaling + factor. 
""" - edges = [] + edges_list = [] + stream_ids = [] + self.CatGT_dir = Path(self.CatGT_dir[0]) + for rec_num, recording in enumerate(self.files): # get file names and stuff - spike_data = self.find_file(recording['spike_data']) - spike_meta = self.find_file(files['spike_meta']) + spike_data = self.CatGT_dir / recording['CatGT_ap_data'] + spike_meta = self.CatGT_dir / recording['CatGT_ap_meta'] stream_id = spike_data.as_posix()[-12:-4] - self.gate_idx = spike_data.as_posix()[-18:-16] + stream_ids.append(stream_id) + #self.gate_idx = spike_data.as_posix()[-18:-16] self.trigger_idx = spike_data.as_posix()[-15:-13] # find extracted rising sync edges, rn from CatGT - CatGT_output = self.interim / ('catgt_' + self.name + self.gate_idx) - edges - if stream_id not in streams: - streams[stream_id] = metadata + try: + edges_file = sorted(glob.glob( + rf'{self.CatGT_dir}' + f'/*{stream_id}.xd*.txt', recursive=True) + )[0] + except IndexError as e: + raise PixelsError( + f"Can't load sync pulse rising edges. Did you run CatGT and\ + extract edges? Full error: {e}\n" + ) + # read sync edges + edges = np.loadtxt(edges_file) + edges_list.append(edges) + + # load spike times of the last recording + self._spike_times_data = self._get_spike_times() + #TODO: times? + + # make list np array and calculate difference between streams + edges = np.array(edges_list) + diff = np.diff(edges, axis=0).squeeze() + lag = [None, 'earlier', 'later'] + print(f"""\n{stream_ids[0]} started {abs(diff[0]*1000):.2f}ms\r + {lag[int(np.sign(diff[0]))]} than {stream_ids[1]}.""") + + assert 0 + def process_behaviour(self): """ @@ -612,8 +642,8 @@ def load_recording(self): print("\n> Getting catgt-ed recording...") self.CatGT_dir = Path(self.CatGT_dir[0]) data_dir = self.CatGT_dir - data_file = data_dir / files['catGT_ap_data'] - metadata = data_dir / files['catGT_ap_meta'] + data_file = data_dir / files['CatGT_ap_data'] + metadata = data_dir / files['CatGT_ap_meta'] except: print(f"\n> Getting the orignial recording...") data_file = self.find_file(files['spike_data']) @@ -674,8 +704,8 @@ def sort_spikes(self, CatGT_app=None, old=False): print("\n> Sorting catgt-ed spikes\n") self.CatGT_dir = Path(self.CatGT_dir[0]) - data_file = self.CatGT_dir / files['catGT_ap_data'] - metadata = self.CatGT_dir / files['catGT_ap_meta'] + data_file = self.CatGT_dir / files['CatGT_ap_data'] + metadata = self.CatGT_dir / files['CatGT_ap_meta'] else: print(f"\n> using the orignial spike data.\n") data_file = self.find_file(files['spike_data']) @@ -1443,45 +1473,45 @@ def get_lfp_data(self): """ return self._get_processed_data("_lfp_data", "lfp_processed") - def _get_spike_times(self, catgt=False): + def _get_spike_times(self): """ Returns the sorted spike times. """ - saved = self._spike_times_data - if saved[0] is None: - # TODO: temporarily add catgt arg here, - if catgt: - stream = 'sorted_stream_cat_0' - else: - stream = 'sorted_stream_0' - times = self.processed / stream / f'spike_times.npy' - clust = self.processed / stream / f'spike_clusters.npy' + spike_times = self._spike_times_data - try: - times = np.load(times) - clust = np.load(clust) - except FileNotFoundError: - msg = ": Can't load spike times that haven't been extracted!" 
- raise PixelsError(self.name + msg) - - times = np.squeeze(times) - clust = np.squeeze(clust) - by_clust = {} + if spike_times[0] is None: + for i in range(len(spike_times)): + times = self.ks_output / f'spike_times.npy' + clust = self.ks_output / f'spike_clusters.npy' + assert 0 - for c in np.unique(clust): - c_times = times[clust == c] - uniques, counts = np.unique( - c_times, - return_counts=True, - ) - repeats = c_times[np.where(counts>1)] - if len(repeats>1): - print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") + try: + times = np.load(times) + clust = np.load(clust) + except FileNotFoundError: + msg = ": Can't load spike times that haven't been extracted!" + raise PixelsError(self.name + msg) + + times = np.squeeze(times) + clust = np.squeeze(clust) + by_clust = {} + + for c in np.unique(clust): + c_times = times[clust == c] + uniques, counts = np.unique( + c_times, + return_counts=True, + ) + repeats = c_times[np.where(counts>1)] + if len(repeats>1): + print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") - by_clust[c] = pd.Series(uniques) - saved[0] = pd.concat(by_clust, axis=1, names=['unit']) + by_clust[c] = pd.Series(uniques) + spike_times[0] = pd.concat(by_clust, axis=1, names=['unit']) + else: + print("new stream?") assert 0 - return saved[0] + return spike_times[0] def _get_aligned_spike_times( self, label, event, duration, rate=False, sigma=None, units=None From e4fb312cbea0e2af235edb7dd39acb094eac7adc Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Feb 2023 15:55:48 +0000 Subject: [PATCH 034/658] do not add sync edges files, find them only when needed --- pixels/ioutils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 36106b5..ee7e3a1 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -119,13 +119,6 @@ def get_data_files(data_dir, session_name): ) recording['CatGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") recording['CatGT_ap_meta'] = str(recording['spike_meta']).replace("t0", "tcat") - recording['CatGT_ap_sync'] = str(recording['spike_meta']).replace("t0", - "tcat").replace("meta", "xd_384") - recording['CatGT_ap_sync'] = sorted(glob.glob( - rf'{data_dir}/catgt_{session_name}_g[0-9]' + '/*xd*.txt', recursive=True) - ) - #TODO: does this way of finding sync edges file works? 
- assert 0 files.append(recording) From 7e447317248e4643d203ccc7af31166ff3195a1c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Feb 2023 18:28:26 +0000 Subject: [PATCH 035/658] use imec0 stream as master-clock and sync other pixels data streams to it; load remapped spike times if available --- pixels/behaviours/base.py | 169 +++++++++++++++++++++++++------------- 1 file changed, 113 insertions(+), 56 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d456b43..17d11fc 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -144,16 +144,16 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self.ks_outputs = sorted(glob.glob( str(self.processed) +'/' + f'sorted_stream_cat_[0-9]' )) - for i, output in enumerate(self.ks_outputs): + for stream_num, stream in enumerate(self.ks_outputs): if not len(self.ks_outputs) == 0: - output = Path(output) - if not ((output / 'phy_ks3').exists() and - len(os.listdir(output / 'phy_ks3'))>17): - self.ks_outputs[i] = output + stream = Path(stream) + if not ((stream / 'phy_ks3').exists() and + len(os.listdir(stream / 'phy_ks3'))>17): + self.ks_outputs[stream_num] = stream else: - self.ks_outputs[i] = output / 'phy_ks3' + self.ks_outputs[stream_num] = stream / 'phy_ks3' else: - output = sorted(glob.glob( + stream = sorted(glob.glob( str(self.processed) +'/' + f'sorted_stream_[0-9]' )) @@ -362,7 +362,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): with (self.processed / 'lag.json').open('w') as fd: json.dump(lag_json, fd) - def sync_streams(self, SYNC_BIN): + def sync_streams(self, SYNC_BIN, remap_stream_idx): """ Neuropixels data streams acquired simultaneously are not synchronised, unless they are plugged into the same headstage, which is only the case for @@ -384,6 +384,24 @@ def sync_streams(self, SYNC_BIN): edges_list = [] stream_ids = [] self.CatGT_dir = Path(self.CatGT_dir[0]) + output = self.ks_outputs[remap_stream_idx] / f'spike_times_remapped.npy' + + if output.exists(): + print(f'\n> Spike times from {self.ks_outputs[remap_stream_idx]}\ + already remapped, next session.') + return + + # load spike times that needs to be remapped + times = self.ks_outputs[remap_stream_idx] / f'spike_times.npy' + try: + times = np.load(times) + except FileNotFoundError: + msg = ": Can't load spike times that haven't been extracted!" + raise PixelsError(self.name + msg) + times = np.squeeze(times) + # convert spike times to ms + orig_rate = int(self.spike_meta[0]['imSampRate']) + times_ms = times * self.sample_rate / orig_rate for rec_num, recording in enumerate(self.files): # get file names and stuff @@ -404,22 +422,50 @@ def sync_streams(self, SYNC_BIN): f"Can't load sync pulse rising edges. Did you run CatGT and\ extract edges? Full error: {e}\n" ) - # read sync edges + # read sync edges, ms edges = np.loadtxt(edges_file) - edges_list.append(edges) - - # load spike times of the last recording - self._spike_times_data = self._get_spike_times() - #TODO: times? 
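# --- editor's note: hedged sketch, not part of the patch --------------------
# The added lines below bin CatGT's rising sync edges and rescale spike times
# from a follower probe onto the imec0 clock. A minimal standalone sketch of
# the same idea (toy values; variable names are hypothetical, not the repo's)
# is piecewise-linear interpolation between matched edges:
import numpy as np

edges0 = np.array([0.0, 1.0, 2.0, 3.0])        # master (imec0) rising edges, s
edges1 = np.array([0.0, 1.001, 2.002, 3.003])  # follower (imec1) rising edges, s
spikes1 = np.array([0.5, 1.5, 2.9])            # follower spike times, s

# map each follower spike onto the master time base using the pair of edges
# that bracket it; this is the per-bin scaling the loop below applies per SYNC_BIN
remapped = np.interp(spikes1, edges1, edges0)
# -----------------------------------------------------------------------------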
+ # pick edges by defined bin + binned_edges = np.append( + arr=edges[0::SYNC_BIN], + values=edges[-1], + ) + edges_list.append(binned_edges) - # make list np array and calculate difference between streams + # make list np array and calculate difference between streams to get the + # initial difference edges = np.array(edges_list) - diff = np.diff(edges, axis=0).squeeze() - lag = [None, 'earlier', 'later'] - print(f"""\n{stream_ids[0]} started {abs(diff[0]*1000):.2f}ms\r - {lag[int(np.sign(diff[0]))]} than {stream_ids[1]}.""") - - assert 0 + initial_dt = np.diff(edges, axis=0).squeeze() + lag = [None, 'later', 'earlier'] + print(f"""\n> {stream_ids[0]} started\r + {abs(initial_dt[0]*1000):.2f}ms {lag[int(np.sign(initial_dt[0]))]}\r + and finished\r + {abs(initial_dt[-1]*1000):.2f}ms {lag[-int(np.sign(initial_dt[0]))]}\r + than {stream_ids[1]}.""") + + # create a np array for remapped spike times from imec1 + remapped_times_ms = np.zeros(times.shape) + + # get edge difference within & between streams + within_streams = np.diff(edges) + # calculate scaling factor for each sync bin + scales = within_streams[0] / within_streams[-1] + + # find out which sync period/bin each spike time belongs to, and use + # the scale from that period/bin + for t, time in enumerate(times_ms): + bin_idx = np.where(time > edges[remap_stream_idx])[0][0] + remapped_times_ms[t] = ((time - edges[remap_stream_idx][bin_idx]) * + scales[bin_idx]) + edges[0][bin_idx] + + print(f"""\n> Remap stats {stream_ids[remap_stream_idx]} spike times:\r + median shift {np.median(remapped_times_ms-times_ms):.2f}ms,\r + min shift {np.min(remapped_times_ms-times_ms):.2f}ms,\r + max shift {np.max(remapped_times_ms-times_ms):.2f}ms.""") + + # convert remappmed times back to its original sample index + remapped_times = np.uint64(remapped_times_ms * orig_rate / self.sample_rate) + np.save(output, remapped_times) + print(f'\n> Spike times remapping output saved to\n {output}.') def process_behaviour(self): @@ -1479,39 +1525,40 @@ def _get_spike_times(self): """ spike_times = self._spike_times_data - if spike_times[0] is None: - for i in range(len(spike_times)): - times = self.ks_output / f'spike_times.npy' - clust = self.ks_output / f'spike_clusters.npy' - assert 0 + for stream_num, stream in enumerate(range(len(spike_times))): + try: + times = self.ks_outputs[stream_num] / f'spike_times_remapped.npy' + print(f'\n> Found remapped spike times from\r + {self.ks_outputs[stream_num]}, try to load this.') + except: + times = self.ks_outputs[stream_num] / f'spike_times.npy' + clust = self.ks_outputs[stream_num] / f'spike_clusters.npy' - try: - times = np.load(times) - clust = np.load(clust) - except FileNotFoundError: - msg = ": Can't load spike times that haven't been extracted!" - raise PixelsError(self.name + msg) - - times = np.squeeze(times) - clust = np.squeeze(clust) - by_clust = {} - - for c in np.unique(clust): - c_times = times[clust == c] - uniques, counts = np.unique( - c_times, - return_counts=True, - ) - repeats = c_times[np.where(counts>1)] - if len(repeats>1): - print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") + try: + times = np.load(times) + clust = np.load(clust) + except FileNotFoundError: + msg = ": Can't load spike times that haven't been extracted!" 
+ raise PixelsError(self.name + msg) - by_clust[c] = pd.Series(uniques) - spike_times[0] = pd.concat(by_clust, axis=1, names=['unit']) - else: - print("new stream?") - assert 0 - return spike_times[0] + times = np.squeeze(times) + clust = np.squeeze(clust) + by_clust = {} + + for c in np.unique(clust): + c_times = times[clust == c] + uniques, counts = np.unique( + c_times, + return_counts=True, + ) + repeats = c_times[np.where(counts>1)] + if len(repeats>1): + print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") + + by_clust[c] = pd.Series(uniques) + spike_times[stream_num] = pd.concat(by_clust, axis=1, names=['unit']) + + return spike_times def _get_aligned_spike_times( self, label, event, duration, rate=False, sigma=None, units=None @@ -1526,6 +1573,8 @@ def _get_aligned_spike_times( if units is None: units = self.select_units() + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure old code does not break! spikes = self._get_spike_times()[units] # Convert to ms (self.sample_rate) spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate @@ -1989,7 +2038,9 @@ def align_clips(self, label, event, video_match, duration=1): def get_cluster_info(self): if self._cluster_info is None: - info_file = self.ks_output / 'cluster_info.tsv' + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure put old code under loop of stream so it does not break! + info_file = self.ks_outputs / 'cluster_info.tsv' #print(f"> got cluster info at {info_file}\n") try: @@ -2066,7 +2117,9 @@ def get_spike_waveforms(self, units=None, method='phy'): units = self.select_units() #paramspy = self.processed / 'sorted_stream_0' / 'params.py' - paramspy = self.ks_output / 'params.py' + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure put old code under loop of stream so it does not break! + paramspy = self.ks_outputs / 'params.py' if not paramspy.exists(): raise PixelsError(f"{self.name}: params.py not found") model = load_model(paramspy) @@ -2103,8 +2156,10 @@ def get_spike_waveforms(self, units=None, method='phy'): progress_bar=True, ) recording, _ = self.load_recording() + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure put old code under loop of stream so it does not break! try: - sorting = se.read_kilosort(self.ks_output) + sorting = se.read_kilosort(self.ks_outputs) except ValueError as e: raise PixelsError( f"Can't load sorting object. Did you delete cluster_info.csv? Full error: {e}\n" @@ -2114,7 +2169,9 @@ def get_spike_waveforms(self, units=None, method='phy'): try: template_cache_mod_time = os.path.getmtime(self.interim / 'cache/templates_average.npy') - ks_mod_time = os.path.getmtime(self.ks_output / 'cluster_info.tsv') + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure put old code under loop of stream so it does not break! 
+ ks_mod_time = os.path.getmtime(self.ks_outputs / 'cluster_info.tsv') assert template_cache_mod_time < ks_mod_time check = True # re-extract waveforms print("> Re-extracting waveforms since kilosort output is newer.") From f7feb1a0670d2cf6b414e4d4d9c5c2297fc9a576 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 10 Feb 2023 18:46:02 +0000 Subject: [PATCH 036/658] do not remap if not necessary; return remapped times for plotting; save sync pulse differences to processed dir --- pixels/behaviours/base.py | 66 +++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 17d11fc..5e06179 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -386,22 +386,20 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): self.CatGT_dir = Path(self.CatGT_dir[0]) output = self.ks_outputs[remap_stream_idx] / f'spike_times_remapped.npy' + # do not redo the remapping if not necessary if output.exists(): print(f'\n> Spike times from {self.ks_outputs[remap_stream_idx]}\ already remapped, next session.') - return + cluster_times = self._get_spike_times()[remap_stream_idx] + remapped_cluster_times = self._get_spike_times( + remapped=True)[remap_stream_idx] - # load spike times that needs to be remapped - times = self.ks_outputs[remap_stream_idx] / f'spike_times.npy' - try: - times = np.load(times) - except FileNotFoundError: - msg = ": Can't load spike times that haven't been extracted!" - raise PixelsError(self.name + msg) - times = np.squeeze(times) - # convert spike times to ms - orig_rate = int(self.spike_meta[0]['imSampRate']) - times_ms = times * self.sample_rate / orig_rate + # get first spike time from each cluster, and their difference + clusters_first = cluster_times.iloc[0,:] + remapped_clusters_first = remapped_cluster_times.iloc[0,:] + remap = remapped_clusters_first - clusters_first + + return remap for rec_num, recording in enumerate(self.files): # get file names and stuff @@ -435,6 +433,22 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): # initial difference edges = np.array(edges_list) initial_dt = np.diff(edges, axis=0).squeeze() + # save initial diff for later plotting + np.save(self.processed / 'sync_streams_lag.npy', initial_dt) + + # load spike times that needs to be remapped + times = self.ks_outputs[remap_stream_idx] / f'spike_times.npy' + try: + times = np.load(times) + except FileNotFoundError: + msg = ": Can't load spike times that haven't been extracted!" 
+ raise PixelsError(self.name + msg) + times = np.squeeze(times) + + # convert spike times to ms + orig_rate = int(self.spike_meta[0]['imSampRate']) + times_ms = times * self.sample_rate / orig_rate + lag = [None, 'later', 'earlier'] print(f"""\n> {stream_ids[0]} started\r {abs(initial_dt[0]*1000):.2f}ms {lag[int(np.sign(initial_dt[0]))]}\r @@ -467,6 +481,18 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): np.save(output, remapped_times) print(f'\n> Spike times remapping output saved to\n {output}.') + # load remapped spike times of each cluster + cluster_times = self._get_spike_times()[remap_stream_idx] + remapped_cluster_times = self._get_spike_times( + remapped=True)[remap_stream_idx] + + # get first spike time from each cluster, and their difference + clusters_first = cluster_times.iloc[0,:] + remapped_clusters_first = remapped_cluster_times.iloc[0,:] + remap = remapped_clusters_first - clusters_first + + return remap + def process_behaviour(self): """ @@ -1519,19 +1545,25 @@ def get_lfp_data(self): """ return self._get_processed_data("_lfp_data", "lfp_processed") - def _get_spike_times(self): + def _get_spike_times(self, remapped=False): """ Returns the sorted spike times. + + params + === + remapped: bool, if using remapped (synced with imec0) spike times. + Default: False """ spike_times = self._spike_times_data for stream_num, stream in enumerate(range(len(spike_times))): - try: + if remapped and stream_num > 0: times = self.ks_outputs[stream_num] / f'spike_times_remapped.npy' - print(f'\n> Found remapped spike times from\r - {self.ks_outputs[stream_num]}, try to load this.') - except: + print(f"""\n> Found remapped spike times from\r + {self.ks_outputs[stream_num]}, try to load this.""") + else: times = self.ks_outputs[stream_num] / f'spike_times.npy' + clust = self.ks_outputs[stream_num] / f'spike_clusters.npy' try: From 40bf56377f9ec35b622f8f91f669e67978262ed6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Mar 2023 16:59:43 +0000 Subject: [PATCH 037/658] make sure probe depth & cluster info works with dual-probe recording --- pixels/behaviours/base.py | 143 ++++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 68 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 5e06179..806bab1 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -141,21 +141,23 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self.files = ioutils.get_data_files(self.raw, name) - self.ks_outputs = sorted(glob.glob( - str(self.processed) +'/' + f'sorted_stream_cat_[0-9]' + ks_outputs = sorted(glob.glob( + str(self.processed) +'/' + f'sorted_stream_*' )) - for stream_num, stream in enumerate(self.ks_outputs): - if not len(self.ks_outputs) == 0: - stream = Path(stream) - if not ((stream / 'phy_ks3').exists() and - len(os.listdir(stream / 'phy_ks3'))>17): - self.ks_outputs[stream_num] = stream + self.ks_outputs = [None] * len(ks_outputs) + if not len(ks_outputs) == 0: + for stream_num, stream in enumerate(ks_outputs): + path = Path(stream) + if stream.split('_')[-2] == 'cat': + if not ((path / 'phy_ks3').exists() and + len(os.listdir(path / 'phy_ks3'))>17): + self.ks_outputs[stream_num] = path + else: + self.ks_outputs[stream_num] = path / 'phy_ks3' else: - self.ks_outputs[stream_num] = stream / 'phy_ks3' - else: - stream = sorted(glob.glob( - str(self.processed) +'/' + f'sorted_stream_[0-9]' - )) + self.ks_outputs[stream_num] = path + else: + print(f"\n> {self.name} have not been spike-sorted.") if 
interim_dir is None: self.interim = self.data_dir / 'interim' / self.name @@ -178,14 +180,15 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self._use_cache = True self._cluster_info = None self._good_unit_info = None + self._probe_depths = None self.drop_data() self.spike_meta = [ ioutils.read_meta(self.find_file(f['spike_meta'], copy=False)) for f in self.files ] - self.lfp_meta = [ - ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files - ] + #self.lfp_meta = [ + # ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files + #] # environmental variable PIXELS_CACHE={0,1} can be {disable,enable} cache self.set_cache(bool(int(os.environ.get("PIXELS_CACHE", 1)))) @@ -201,6 +204,8 @@ def drop_data(self): self._spike_rate_data = [None] * len(self.files) self._lfp_data = [None] * len(self.files) self._motion_index = [None] * len(self.files) + self._cluster_info = [None] * len(self.files) + self._probe_depths = [None] * len(self.files) self._load_lag() def set_cache(self, on: bool | Literal["overwrite"]) -> None: @@ -230,20 +235,26 @@ def get_probe_depth(self): """ Load probe depth in um from file if it has been recorded. """ - depth_file = self.processed / 'depth.txt' - if not depth_file.exists(): - depth_file = self.processed / self.files[0]["depth_info"] - - if not depth_file.exists(): - msg = f": Can't load probe depth: please add it in um to\ - \nprocessed/{self.name}/depth.txt, or save it with other depth related\ - \ninfo in {self.processed / self.files[0]['depth_info']}." - raise PixelsError(msg) - if Path(depth_file).suffix == ".txt": - with depth_file.open() as fd: - return [float(line) for line in fd.readlines()] - elif Path(depth_file).suffix == ".json": - return [json.load(open(depth_file, mode="r"))["clustering"]] + for stream_num, depth in enumerate(self._probe_depths): + if depth is None: + try: + depth_file = self.processed / 'depth.txt' + with depth_file.open() as fd: + self._probe_depths[stream_num] = [float(line) for line in + fd.readlines()][0] + except: + depth_file = self.processed / self.files[stream_num]["depth_info"] + self._probe_depths[stream_num] = json.load(open(depth_file, mode="r"))["clustering"] + else: + msg = f": Can't load probe depth: please add it in um to\ + \nprocessed/{self.name}/depth.txt, or save it with other depth related\ + \ninfo in {self.processed / self.files[0]['depth_info']}." 
+ raise PixelsError(msg) + + #if Path(depth_file).suffix == ".txt": + #elif Path(depth_file).suffix == ".json": + # return [json.load(open(depth_file, mode="r"))["clustering"]] + return self._probe_depths def find_file(self, name: str, copy: bool=True) -> Optional[Path]: """ @@ -1738,7 +1749,7 @@ def select_units( selected_units.name = name if min_depth is not None or max_depth is not None: - probe_depth = self.get_probe_depth()[0] + probe_depths = self.get_probe_depth() if min_spike_width == 0: min_spike_width = None @@ -1747,37 +1758,36 @@ def select_units( else: widths = None - rec_num = 0 - - id_key = 'id' if 'id' in cluster_info else 'cluster_id' - grouping = 'KSLabel' if uncurated else 'group' + for stream_num, info in enumerate(cluster_info): + id_key = 'id' if 'id' in info else 'cluster_id' + grouping = 'KSLabel' if uncurated else 'group' - for unit in cluster_info[id_key]: - unit_info = cluster_info.loc[cluster_info[id_key] == unit].iloc[0].to_dict() + for unit in info[id_key]: + unit_info = info.loc[info[id_key] == unit].iloc[0].to_dict() - # we only want units that are in the specified group - if not group or unit_info[grouping] == group: + # we only want units that are in the specified group + if not group or unit_info[grouping] == group: - # and that are within the specified depth range - if min_depth is not None: - if probe_depth - unit_info['depth'] <= min_depth: - continue - if max_depth is not None: - if probe_depth - unit_info['depth'] > max_depth: - continue - - # and that have the specified median spike widths - if widths is not None: - width = widths[widths['unit'] == unit]['median_ms'] - assert len(width.values) == 1 - if min_spike_width is not None: - if width.values[0] < min_spike_width: + # and that are within the specified depth range + if min_depth is not None: + if probe_depths[stream_num] - unit_info['depth'] <= min_depth: continue - if max_spike_width is not None: - if width.values[0] > max_spike_width: + if max_depth is not None: + if probe_depths[stream_num] - unit_info['depth'] > max_depth: continue - selected_units.append(unit) + # and that have the specified median spike widths + if widths is not None: + width = widths[widths['unit'] == unit]['median_ms'] + assert len(width.values) == 1 + if min_spike_width is not None: + if width.values[0] < min_spike_width: + continue + if max_spike_width is not None: + if width.values[0] > max_spike_width: + continue + + selected_units.append(unit) return selected_units @@ -2069,18 +2079,15 @@ def align_clips(self, label, event, video_match, duration=1): return trials def get_cluster_info(self): - if self._cluster_info is None: - #TODO: with multiple streams, spike times will be a list with multiple dfs, - #make sure put old code under loop of stream so it does not break! - info_file = self.ks_outputs / 'cluster_info.tsv' - #print(f"> got cluster info at {info_file}\n") - - try: - info = pd.read_csv(info_file, sep='\t') - except FileNotFoundError: - msg = ": Can't load cluster info. Did you sort this session yet?" - raise PixelsError(self.name + msg) - self._cluster_info = info + for stream_num, info in enumerate(self._cluster_info): + if info is None: + info_file = self.ks_outputs[stream_num] / 'cluster_info.tsv' + try: + info = pd.read_csv(info_file, sep='\t') + except FileNotFoundError: + msg = ": Can't load cluster info. Did you sort this session yet?" 
+ raise PixelsError(self.name + msg) + self._cluster_info[stream_num] = info return self._cluster_info def get_good_units_info(self): From d5deea51143c795f68c00129c6d2157bd30db98e Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Mar 2023 17:07:43 +0000 Subject: [PATCH 038/658] do not error if no lfp data --- pixels/ioutils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index ee7e3a1..706b924 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -59,10 +59,10 @@ def get_data_files(data_dir, session_name): raise PixelsError(f"{session_name}: could not find raw AP data file.") if not spike_meta: raise PixelsError(f"{session_name}: could not find raw AP metadata file.") - if not lfp_data: - raise PixelsError(f"{session_name}: could not find raw LFP data file.") - if not lfp_meta: - raise PixelsError(f"{session_name}: could not find raw LFP metadata file.") + #if not lfp_data: + # raise PixelsError(f"{session_name}: could not find raw LFP data file.") + #if not lfp_meta: + # raise PixelsError(f"{session_name}: could not find raw LFP metadata file.") camera_data = [] camera_meta = [] @@ -78,8 +78,8 @@ def get_data_files(data_dir, session_name): recording = {} recording['spike_data'] = original_name(spike_recording) recording['spike_meta'] = original_name(spike_meta[num]) - recording['lfp_data'] = original_name(lfp_data[num]) - recording['lfp_meta'] = original_name(lfp_meta[num]) + #recording['lfp_data'] = original_name(lfp_data[num]) + #recording['lfp_meta'] = original_name(lfp_meta[num]) if behaviour: if len(behaviour) == len(spike_data): @@ -114,7 +114,7 @@ def get_data_files(data_dir, session_name): recording['clustered_channels'] = recording['lfp_data'].with_name( f'channel_clustering_results_{num}.h5' ) - recording['depth_info'] = recording['lfp_data'].with_name( + recording['depth_info'] = recording['spike_data'].with_name( f'depth_info_{num}.json' ) recording['CatGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") From 7ec46b27eeb38d3519fbfc2c1d276c130c5c1078 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 15 Mar 2023 21:52:44 +0000 Subject: [PATCH 039/658] get spike times --- pixels/experiment.py | 48 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 710feaf..ba3805c 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -283,7 +283,7 @@ def get_good_units_info(self): ) info[mouse] = long_df - info_per_mouse = info + mouse_info = info info_pooled = pd.concat( info, axis=0, copy=False, @@ -291,7 +291,7 @@ def get_good_units_info(self): names=["mouse", "session", "unit_idx"], ) - return info_per_mouse, info_pooled + return mouse_info, info_pooled def get_spike_widths(self, units=None): """ @@ -325,7 +325,10 @@ def get_spike_waveforms(self, units=None): waveforms[i] = session.get_spike_waveforms(units=units[i]) else: waveforms[i] = session.get_spike_waveforms() + assert 0 + df.add_prefix(i) + #TODO: get concat waveforms for each mouse df = pd.concat( waveforms.values(), axis=1, copy=False, keys=waveforms.keys(), @@ -363,7 +366,7 @@ def get_waveform_metrics(self, units=None): long_df.dropna(inplace=True) waveform_metrics[mouse] = long_df - waveform_metrics_per_mouse = waveform_metrics + mouse_waveform_metrics = waveform_metrics waveform_metrics_pooled = pd.concat( waveform_metrics, axis=0, copy=False, @@ -371,7 +374,44 @@ def get_waveform_metrics(self, units=None): 
names=["mouse", "session", "unit_idx"], ) - return waveform_metrics_per_mouse, waveform_metrics_pooled + return mouse_waveform_metrics, waveform_metrics_pooled + + def get_spike_times(self, units): + """ + Get spike times of each units, separated by mouse. + """ + spike_times = {} + + for m, mouse in enumerate(self.mouse_ids): + mouse_sessions = [] + spike_times[mouse] = {} + for session in self.sessions: + if mouse in session.name: + mouse_sessions.append(session) + + for i, session in enumerate(mouse_sessions): + spike_times[mouse][i] = session._get_spike_times()[units[i]] + #spike_times[mouse][i] = spike_times[mouse][i].add_prefix(f'{i}_') + + df = pd.concat( + spike_times[mouse], + axis=1, + names=["session", "unit"], + ) + spike_times[mouse] = df + + mouse_spike_times = spike_times + + """ + spike_times_pooled = pd.concat( + mouse_spike_times, axis=1, copy=False, + keys=mouse_spike_times.keys(), + names=["mouse", "session", "unit"], + ) + """ + + return mouse_spike_times + def get_aligned_spike_rate_CI(self, *args, units=None, **kwargs): """ From de33ba56821f207048246ee0b085b847f6d58f70 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 15 Mar 2023 21:53:40 +0000 Subject: [PATCH 040/658] index into spike time output to make sure of getting df and can be indexed into using unit id --- pixels/behaviours/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 806bab1..1504610 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1601,7 +1601,7 @@ def _get_spike_times(self, remapped=False): by_clust[c] = pd.Series(uniques) spike_times[stream_num] = pd.concat(by_clust, axis=1, names=['unit']) - return spike_times + return spike_times[0] def _get_aligned_spike_times( self, label, event, duration, rate=False, sigma=None, units=None @@ -1618,6 +1618,7 @@ def _get_aligned_spike_times( #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! 
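# --- editor's note: hedged sketch, not part of the patch --------------------
# Concatenating per-session spike-time frames with keys (as in the new
# Experiment.get_spike_times above) yields hierarchical columns, so plain
# indexing by unit id no longer works; the 'unit' level has to be selected
# explicitly. Toy illustration with hypothetical unit ids:
import numpy as np
import pandas as pd

a = pd.DataFrame(np.arange(6).reshape(3, 2), columns=[10, 11])  # session 0
b = pd.DataFrame(np.arange(6).reshape(3, 2), columns=[10, 12])  # session 1
both = pd.concat({0: a, 1: b}, axis=1, names=["session", "unit"])

per_session = both[0]                         # columns are unit ids again
one_unit = both.xs(10, axis=1, level="unit")  # unit 10 across both sessions
# -----------------------------------------------------------------------------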
+ #TODO: spike times cannot be indexed by unit ids anymore spikes = self._get_spike_times()[units] # Convert to ms (self.sample_rate) spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate From 027a28162b7f79187f0cf1687270e80b23ccd9bc Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 Jun 2023 16:57:30 +0100 Subject: [PATCH 041/658] add todo; now call trough-to-peak 'duration' to be consistent with jia et al --- pixels/behaviours/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1504610..84fddff 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2278,7 +2278,8 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): wave_metrics = self.get_waveform_metrics() return wave_metrics.loc[wave_metrics.unit.isin(units)] - columns = ["unit", "trough_to_peak", "trough_peak_ratio", "half_width", + # TODO june 2nd 2023: extract amplitude, i.e., abs(trough - peak) in mV + columns = ["unit", "duration", "trough_peak_ratio", "half_width", "repolarisation_slope", "recovery_slope"] print(f"> Calculating waveform metrics {columns[1:]}...\n") @@ -2304,8 +2305,8 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): raise PixelsError(f"> Cannot find peak in mean waveform.\n") if trough_idx == 0: raise PixelsError(f"> Cannot find trough in mean waveform.\n") - trough_to_peak = mean_waveform.index[peak_idx] - mean_waveform.index[trough_idx] - metrics.append(trough_to_peak) + duration = mean_waveform.index[peak_idx] - mean_waveform.index[trough_idx] + metrics.append(duration) # trough to peak ratio trough_peak_ratio = mean_waveform.iloc[peak_idx] / mean_waveform.iloc[trough_idx] From 9b52801c820cc8a11125f21dc1bd4eb68032f6a9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 Jun 2023 16:58:21 +0100 Subject: [PATCH 042/658] do not exclude lfp data --- pixels/ioutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 706b924..f5d9578 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -78,8 +78,8 @@ def get_data_files(data_dir, session_name): recording = {} recording['spike_data'] = original_name(spike_recording) recording['spike_meta'] = original_name(spike_meta[num]) - #recording['lfp_data'] = original_name(lfp_data[num]) - #recording['lfp_meta'] = original_name(lfp_meta[num]) + recording['lfp_data'] = original_name(lfp_data[num]) + recording['lfp_meta'] = original_name(lfp_meta[num]) if behaviour: if len(behaviour) == len(spike_data): From 3174e19a8baefea9ac0888fdd71feb334885276d Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 Jun 2023 17:23:45 +0100 Subject: [PATCH 043/658] add todos --- pixels/behaviours/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 84fddff..f9deb0f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2279,6 +2279,8 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): return wave_metrics.loc[wave_metrics.unit.isin(units)] # TODO june 2nd 2023: extract amplitude, i.e., abs(trough - peak) in mV + # make sure amplitude is in mV + # normalise these metrics before passing to k-means columns = ["unit", "duration", "trough_peak_ratio", "half_width", "repolarisation_slope", "recovery_slope"] print(f"> Calculating waveform metrics {columns[1:]}...\n") From 208cf669a7f7a9149d463ca9ff769a4ba92d3b80 Mon Sep 17 00:00:00 2001 
From: amz_office Date: Thu, 4 Jan 2024 18:57:59 +0000 Subject: [PATCH 044/658] add todo; extract 350+-160ms sync edges --- pixels/behaviours/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f9deb0f..9dd58ba 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -699,7 +699,7 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: -lf\ -apfilter=butter,12,300,9000\ -lffilter=butter,12,0.5,300\ - -xd=2,0,384,6,500\ + -xd=2,0,384,6,350,160\ -gblcar\ -gfix=0.2,0.1,0.02" @@ -772,15 +772,15 @@ def sort_spikes(self, CatGT_app=None, old=False): streams = {} # set chunks for spikeinterface operations job_kwargs = dict( - n_jobs=10, # -1: num of job equals num of cores + n_jobs=20, # -1: num of job equals num of cores chunk_duration="1s", progress_bar=True, ) - concat_rec, output = self.load_recording() + #concat_rec, output = self.load_recording() + #assert 0 + #TODO: jan 3 see if ks can run normally now using load_recording() - assert 0 - #TODO: see if ks can run normally now using load_recording() for _, files in enumerate(self.files): if not CatGT_app == None: self.run_catgt(CatGT_app=CatGT_app) From 778710e8eac298d1575ad70bc7228accbcdff1ee Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 4 Jan 2024 19:11:33 +0000 Subject: [PATCH 045/658] make sure catgt only runs once if multiple streams in one session --- pixels/behaviours/base.py | 105 ++++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 43 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 9dd58ba..cd0e209 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -157,12 +157,12 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): else: self.ks_outputs[stream_num] = path else: - print(f"\n> {self.name} have not been spike-sorted.") + print(f"\n> {self.name} has not been spike-sorted.") if interim_dir is None: self.interim = self.data_dir / 'interim' / self.name else: - self.interim = Path(interim_dir) / self.name + self.interim = Path(interim_dir).expanduser() / self.name self.CatGT_dir = sorted(glob.glob( str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' @@ -184,7 +184,7 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): self.drop_data() self.spike_meta = [ - ioutils.read_meta(self.find_file(f['spike_meta'], copy=False)) for f in self.files + ioutils.read_meta(self.find_file(f['spike_meta'], copy=True)) for f in self.files ] #self.lfp_meta = [ # ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files @@ -394,7 +394,7 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): """ edges_list = [] stream_ids = [] - self.CatGT_dir = Path(self.CatGT_dir[0]) + #self.CatGT_dir = Path(self.CatGT_dir[0]) output = self.ks_outputs[remap_stream_idx] / f'spike_times_remapped.npy' # do not redo the remapping if not necessary @@ -414,8 +414,11 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): for rec_num, recording in enumerate(self.files): # get file names and stuff - spike_data = self.CatGT_dir / recording['CatGT_ap_data'] - spike_meta = self.CatGT_dir / recording['CatGT_ap_meta'] + # TODO jan 4 check if find_file works for catgt data + spike_data = self.find_file(recording['CatGT_ap_data']) + spike_meta = self.find_file(recording['CatGT_ap_meta']) + #spike_data = self.CatGT_dir / recording['CatGT_ap_data'] + #spike_meta = self.CatGT_dir / recording['CatGT_ap_meta'] stream_id = 
spike_data.as_posix()[-12:-4] stream_ids.append(stream_id) #self.gate_idx = spike_data.as_posix()[-18:-16] @@ -674,39 +677,46 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: # move cwd to catgt os.chdir(CatGT_app) - for rec_num, recording in enumerate(self.files): - # copy spike data to interim - self.find_file(recording['spike_data']) - - # reset catgt args for current session - session_args = None + # reset catgt args for current session + session_args = None - if len(self.CatGT_dir) != 0: - if len(os.listdir(self.CatGT_dir[0])) != 0: - print(f"\nCatGT already performed on ap data of {self.name}. Next session.\n") - continue + for f in self.files: + # copy spike data to interim + self.find_file(f['spike_data']) + if isinstance(self.CatGT_dir, list) and + len(self.CatGT_dir) != 0 and + len(os.listdir(self.CatGT_dir[0])) != 0: + print(f"\nCatGT already performed on ap data of {self.name}. Next session.\n") + return + else: #TODO: finish this here so that catgt can run together with sorting print(f"> Running CatGT on ap data of {self.name}") #_dir = self.interim - if args == None: - args = f"-no_run_fld\ - -g=0,9\ - -t=0,9\ - -prb=0\ - -ap\ - -lf\ - -apfilter=butter,12,300,9000\ - -lffilter=butter,12,0.5,300\ - -xd=2,0,384,6,350,160\ - -gblcar\ - -gfix=0.2,0.1,0.02" - - session_args = f"-dir={self.interim} -run={self.name} -dest={self.interim} " + args - print(f"\ncatgt args of {self.name}: \n{session_args}") - - subprocess.run( ['./run_catgt.sh', session_args]) + if args == None: + args = f"-no_run_fld\ + -g=0,9\ + -t=0,9\ + -prb=0:1\ + -prb_miss_ok\ + -ap\ + -lf\ + -apfilter=butter,12,300,9000\ + -lffilter=butter,12,0.5,300\ + -xd=2,0,384,6,350,160\ + -gblcar\ + -gfix=0.2,0.1,0.02" + + session_args = f"-dir={self.interim} -run={self.name} -dest={self.interim} " + args + print(f"\ncatgt args of {self.name}: \n{session_args}") + + subprocess.run( ['./run_catgt.sh', session_args]) + + # make sure CatGT_dir is set after running + self.CatGT_dir = sorted(glob.glob( + str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' + )) def load_recording(self): @@ -721,6 +731,7 @@ def load_recording(self): except: for _, files in enumerate(self.files): + # TODO jan 4 check if can put line 798-808 here try: print("\n> Getting catgt-ed recording...") self.CatGT_dir = Path(self.CatGT_dir[0]) @@ -781,22 +792,30 @@ def sort_spikes(self, CatGT_app=None, old=False): #assert 0 #TODO: jan 3 see if ks can run normally now using load_recording() + self.run_catgt(CatGT_app=CatGT_app) + for _, files in enumerate(self.files): if not CatGT_app == None: - self.run_catgt(CatGT_app=CatGT_app) - print("\n> Sorting catgt-ed spikes\n") - self.CatGT_dir = Path(self.CatGT_dir[0]) - data_file = self.CatGT_dir / files['CatGT_ap_data'] - metadata = self.CatGT_dir / files['CatGT_ap_meta'] + basename = self.CatGT_dir[0].split('/')[-1] + files['CatGT_ap_data'] = basename + "/" + files['CatGT_ap_data'] + files['CatGT_ap_meta'] = basename + "/" + files['CatGT_ap_meta'] + data_type = 'CatGT_ap_data' + meta_type = 'CatGT_ap_meta' else: print(f"\n> using the orignial spike data.\n") - data_file = self.find_file(files['spike_data']) - metadata = self.find_file(files['spike_meta']) + data_type = 'spike_data' + meta_type = 'spike_meta' + + data_file = self.find_file(files[data_type]) + metadata = self.find_file(files[meta_type]) - stream_id = data_file.as_posix()[-12:-4] - if stream_id not in streams: - streams[stream_id] = metadata + stream_id = data_file.as_posix()[-12:-4] + if stream_id not in streams: + streams[stream_id] = 
metadata + assert 0 + #TODO jan 4 check if with multiple streams, do i need to go out of + #the loop for stream_num, stream in enumerate(streams.items()): stream_id, metadata = stream From c149ee5c3aaf07fa1f778c42cb3c713dc241c252 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 5 Jan 2024 16:48:58 +0000 Subject: [PATCH 046/658] add processed dir; fix typo --- pixels/behaviours/base.py | 38 +++++++++++++++++++++++--------------- pixels/experiment.py | 2 ++ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index cd0e209..a382788 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -131,13 +131,23 @@ class Behaviour(ABC): sample_rate = 1000 - def __init__(self, name, data_dir, metadata=None, interim_dir=None): + def __init__(self, name, data_dir, metadata=None, processed_dir=None, + interim_dir=None): self.name = name self.data_dir = data_dir self.metadata = metadata self.raw = self.data_dir / 'raw' / self.name - self.processed = self.data_dir / 'processed' / self.name + + if interim_dir is None: + self.interim = self.data_dir / 'interim' / self.name + else: + self.interim = Path(interim_dir).expanduser() / self.name + + if processed_dir is None: + self.processed = self.data_dir / 'processed' / self.name + else: + self.processed = Path(processed_dir).expanduser() / self.name self.files = ioutils.get_data_files(self.raw, name) @@ -159,11 +169,6 @@ def __init__(self, name, data_dir, metadata=None, interim_dir=None): else: print(f"\n> {self.name} has not been spike-sorted.") - if interim_dir is None: - self.interim = self.data_dir / 'interim' / self.name - else: - self.interim = Path(interim_dir).expanduser() / self.name - self.CatGT_dir = sorted(glob.glob( str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' )) @@ -684,9 +689,9 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: # copy spike data to interim self.find_file(f['spike_data']) - if isinstance(self.CatGT_dir, list) and + if (isinstance(self.CatGT_dir, list) and len(self.CatGT_dir) != 0 and - len(os.listdir(self.CatGT_dir[0])) != 0: + len(os.listdir(self.CatGT_dir[0])) != 0): print(f"\nCatGT already performed on ap data of {self.name}. 
Next session.\n") return else: @@ -783,7 +788,7 @@ def sort_spikes(self, CatGT_app=None, old=False): streams = {} # set chunks for spikeinterface operations job_kwargs = dict( - n_jobs=20, # -1: num of job equals num of cores + n_jobs=0.9, # -1: num of job equals num of cores chunk_duration="1s", progress_bar=True, ) @@ -813,9 +818,6 @@ def sort_spikes(self, CatGT_app=None, old=False): stream_id = data_file.as_posix()[-12:-4] if stream_id not in streams: streams[stream_id] = metadata - assert 0 - #TODO jan 4 check if with multiple streams, do i need to go out of - #the loop for stream_num, stream in enumerate(streams.items()): stream_id, metadata = stream @@ -834,7 +836,10 @@ def sort_spikes(self, CatGT_app=None, old=False): continue try: - recording = se.SpikeGLXRecordingExtractor(self.CatGT_dir, stream_id=stream_id) + recording = se.SpikeGLXRecordingExtractor( + self.CatGT_dir[0], + stream_id=stream_id, + ) # this recording is filtered recording.annotate(is_filtered=True) except ValueError as e: @@ -860,6 +865,7 @@ def sort_spikes(self, CatGT_app=None, old=False): else: try: ks3_output = si.load_extractor(output / 'saved_si_sorting_obj') + #sorting_KS = read_kilosort(folder_path="kilosort-folder") print("> This session is already sorted, now it is loaded.\n") """ @@ -883,10 +889,11 @@ def sort_spikes(self, CatGT_app=None, old=False): sorter_name='kilosort3', recording=concat_rec, #recording=test, # for testing output_folder=output, - #remove_existing_folder=False, + remove_existing_folder=True, **job_kwargs, ) + assert 0 # remove empty units ks3_output = sorting.remove_empty_units() print(f"> KS3 removed\ @@ -954,6 +961,7 @@ def sort_spikes(self, CatGT_app=None, old=False): ) print(f"> Parameters for manual curation saved to {for_phy}.\n") + assert 0 correct_kslabels = for_phy / "cluster_KSLabel.tsv" if correct_kslabels.exists(): print(f"\nCorrect KS labels already saved in {correct_kslabels}. 
Next session.\n") diff --git a/pixels/experiment.py b/pixels/experiment.py index ba3805c..8466449 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -50,6 +50,7 @@ def __init__( data_dir, meta_dir=None, interim_dir=None, + processed_dir=None, session_date_fmt="%y%m%d", ): if not isinstance(mouse_ids, (list, tuple, set)): @@ -80,6 +81,7 @@ def __init__( metadata=[s['metadata'] for s in metadata], data_dir=metadata[0]['data_dir'], interim_dir=interim_dir, + processed_dir=processed_dir, ) ) From 12e4a740d6c4cab58058db4741f8951470d28134 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 9 Jan 2024 18:35:03 +0000 Subject: [PATCH 047/658] adding session & stream name to print; use 90% cores; separate cache folder for each stream --- pixels/behaviours/base.py | 47 ++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a382788..1b2b797 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -801,7 +801,7 @@ def sort_spikes(self, CatGT_app=None, old=False): for _, files in enumerate(self.files): if not CatGT_app == None: - print("\n> Sorting catgt-ed spikes\n") + print("\n> Sorting catgt-ed spikes.\n") basename = self.CatGT_dir[0].split('/')[-1] files['CatGT_ap_data'] = basename + "/" + files['CatGT_ap_data'] files['CatGT_ap_meta'] = basename + "/" + files['CatGT_ap_meta'] @@ -830,9 +830,9 @@ def sort_spikes(self, CatGT_app=None, old=False): # check if already sorted and exported for_phy = output / "phy_ks3" if not for_phy.exists() or not len(os.listdir(for_phy)) > 1: - print("> Not sorted or exported yet, start from spike sorting...\n") + print(f"> {self.name}{stream_id} not sorted or exported.\n") else: - print("> Already sorted and exported, next session...\n") + print("> Already sorted and exported, next session.\n") continue try: @@ -866,7 +866,7 @@ def sort_spikes(self, CatGT_app=None, old=False): try: ks3_output = si.load_extractor(output / 'saved_si_sorting_obj') #sorting_KS = read_kilosort(folder_path="kilosort-folder") - print("> This session is already sorted, now it is loaded.\n") + print(f"> {self.name}{stream_id} is already sorted, now it is loaded.\n") """ # for testing: get first 5 mins of the recording @@ -882,8 +882,7 @@ def sort_spikes(self, CatGT_app=None, old=False): """ except: - print("> Running kilosort\n") - print(f"> Now is sorting: \n{concat_rec}\n") + print(f"> Now kilosorting {self.name}{stream_id}: \n{concat_rec}\n") #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) sorting = ss.run_sorter( sorter_name='kilosort3', @@ -893,7 +892,6 @@ def sort_spikes(self, CatGT_app=None, old=False): **job_kwargs, ) - assert 0 # remove empty units ks3_output = sorting.remove_empty_units() print(f"> KS3 removed\ @@ -913,26 +911,28 @@ def sort_spikes(self, CatGT_app=None, old=False): #TODO: toggle load_if_exists=True & overwrite=False should replace #...load_from_folder. 
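# --- editor's note: hedged sketch of the simplification the TODO above
# suggests. With load_if_exists=True, si.extract_waveforms() returns the
# cached WaveformExtractor from `folder` when it already exists (as the call
# further below already relies on), so the explicit try/except around
# WaveformExtractor.load_from_folder() may be redundant, e.g.:
#
#     waveforms = si.extract_waveforms(
#         recording=concat_rec,
#         sorting=ks3_output,
#         folder=cache,
#         load_if_exists=True,
#         overwrite=False,
#         **job_kwargs,
#     )
#
# (assumes the spikeinterface version in use supports load_if_exists)
# -----------------------------------------------------------------------------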
+ cache = self.interim / f'cache_{stream_num}' try: waveforms = si.WaveformExtractor.load_from_folder( - folder=self.interim / 'cache', + folder=cache, sorting=ks3_output, ) - print("> Waveforms extracted, now it is loaded.\n") + print("> {self.name}{stream_id} waveforms extracted, now it is loaded.\n") except: - print("> Waveforms not extracted, extracting now.\n") + print("> {self.name}{stream_id} waveforms not extracted, extracting now.\n") + #if ks3_output.count_total_num_spikes() # extract waveforms waveforms = si.extract_waveforms( recording=concat_rec, #recording=test, # for testing sorting=ks3_output, - folder=self.interim / 'cache', - #load_if_exists=True, # load extracted if available - load_if_exists=False, # re-calculate everytime + folder=cache, + load_if_exists=True, # load extracted if available + #load_if_exists=False, # re-calculate everytime max_spikes_per_unit=500, # None will extract all waveforms ms_before=2.0, # time before trough ms_after=3.0, # time after trough - #overwrite=False, - overwrite=True, + overwrite=False, + #overwrite=True, **job_kwargs, ) @@ -948,20 +948,19 @@ def sort_spikes(self, CatGT_app=None, old=False): # export to phy, with pc feature calculated. # copy recording.dat to output so that individual waveforms can be # seen in waveformview. - print("\n> Exporting parameters for phy...\n") + print(f"\n> Exporting {self.name}{stream_id} parameters for phy...\n") sexp.export_to_phy( waveform_extractor=waveforms, output_folder=for_phy, compute_pc_features=True, # pca compute_amplitudes=True, - copy_binary=True, + copy_binary=False, #remove_if_exists=True, # overwrite everytime remove_if_exists=False, # load if already exists **job_kwargs, ) print(f"> Parameters for manual curation saved to {for_phy}.\n") - assert 0 correct_kslabels = for_phy / "cluster_KSLabel.tsv" if correct_kslabels.exists(): print(f"\nCorrect KS labels already saved in {correct_kslabels}. 
Next session.\n") @@ -969,7 +968,7 @@ def sort_spikes(self, CatGT_app=None, old=False): print("\n> Getting all KS labels...") all_ks_labels = pd.read_csv( - output / "cluster_KSLabel.tsv", + output / "sorter_output/cluster_KSLabel.tsv", sep='\t', ) print("\n> Finding cluster ids from spikeinterface output...") @@ -988,6 +987,14 @@ def sort_spikes(self, CatGT_app=None, old=False): index=False, ) + # copy params.py from sorter_output to phy_ks3 + print(f"\n> Copying params.py to {for_phy}...") + copyfile(output / "sorter_output/params.py", for_phy / "params.py") + + # TODO jan 8 in sorter_output, only keep params, recording and + # temp_wh, delete the rest + print(f"\n> {self.name}{stream_id} spike-sorted.") + def extract_videos(self, force=False): """ @@ -2218,7 +2225,7 @@ def get_spike_waveforms(self, units=None, method='phy'): elif method == 'spikeinterface': # set chunks job_kwargs = dict( - n_jobs=10, # -1: num of job equals num of cores + n_jobs=0.9, # -1: num of job equals num of cores chunk_duration="1s", progress_bar=True, ) From 8d3e0ce491584af27c28829ed8ec6df054a782cd Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 7 Mar 2024 16:48:35 +0000 Subject: [PATCH 048/658] change number of cores for joblib multiprocessing --- pixels/behaviours/base.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1b2b797..89dec52 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -696,7 +696,7 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: return else: #TODO: finish this here so that catgt can run together with sorting - print(f"> Running CatGT on ap data of {self.name}") + print(f"\n> Running CatGT on ap data of {self.name}") #_dir = self.interim if args == None: @@ -788,7 +788,7 @@ def sort_spikes(self, CatGT_app=None, old=False): streams = {} # set chunks for spikeinterface operations job_kwargs = dict( - n_jobs=0.9, # -1: num of job equals num of cores + n_jobs=-3, # -1: num of job equals num of cores chunk_duration="1s", progress_bar=True, ) @@ -830,7 +830,7 @@ def sort_spikes(self, CatGT_app=None, old=False): # check if already sorted and exported for_phy = output / "phy_ks3" if not for_phy.exists() or not len(os.listdir(for_phy)) > 1: - print(f"> {self.name}{stream_id} not sorted or exported.\n") + print(f"> {self.name} {stream_id} not sorted/exported.\n") else: print("> Already sorted and exported, next session.\n") continue @@ -866,7 +866,7 @@ def sort_spikes(self, CatGT_app=None, old=False): try: ks3_output = si.load_extractor(output / 'saved_si_sorting_obj') #sorting_KS = read_kilosort(folder_path="kilosort-folder") - print(f"> {self.name}{stream_id} is already sorted, now it is loaded.\n") + print(f"> {self.name} {stream_id} is already sorted, now it is loaded.\n") """ # for testing: get first 5 mins of the recording @@ -882,7 +882,7 @@ def sort_spikes(self, CatGT_app=None, old=False): """ except: - print(f"> Now kilosorting {self.name}{stream_id}: \n{concat_rec}\n") + print(f"> Now kilosorting {self.name} {stream_id}: \n{concat_rec}\n") #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) sorting = ss.run_sorter( sorter_name='kilosort3', @@ -917,9 +917,9 @@ def sort_spikes(self, CatGT_app=None, old=False): folder=cache, sorting=ks3_output, ) - print("> {self.name}{stream_id} waveforms extracted, now it is loaded.\n") + print(f"> {self.name} {stream_id} waveforms extracted, now it is loaded.\n") except: - print("> 
{self.name}{stream_id} waveforms not extracted, extracting now.\n") + print(f"> {self.name} {stream_id} waveforms not extracted, extracting now.\n") #if ks3_output.count_total_num_spikes() # extract waveforms waveforms = si.extract_waveforms( @@ -948,7 +948,7 @@ def sort_spikes(self, CatGT_app=None, old=False): # export to phy, with pc feature calculated. # copy recording.dat to output so that individual waveforms can be # seen in waveformview. - print(f"\n> Exporting {self.name}{stream_id} parameters for phy...\n") + print(f"\n> Exporting {self.name} {stream_id} parameters for phy...\n") sexp.export_to_phy( waveform_extractor=waveforms, output_folder=for_phy, @@ -993,7 +993,7 @@ def sort_spikes(self, CatGT_app=None, old=False): # TODO jan 8 in sorter_output, only keep params, recording and # temp_wh, delete the rest - print(f"\n> {self.name}{stream_id} spike-sorted.") + print(f"\n> {self.name} {stream_id} spike-sorted.\n") def extract_videos(self, force=False): @@ -2225,7 +2225,7 @@ def get_spike_waveforms(self, units=None, method='phy'): elif method == 'spikeinterface': # set chunks job_kwargs = dict( - n_jobs=0.9, # -1: num of job equals num of cores + n_jobs=-3, # -1: num of job equals num of cores chunk_duration="1s", progress_bar=True, ) From d36380cd3519882813ffb8553234f61974be9a3d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 7 Mar 2024 19:18:27 +0000 Subject: [PATCH 049/658] create virtual reality tunnel behaviour --- pixels/behaviours/vitual_tunnel.py | 514 +++++++++++++++++++++++++++++ 1 file changed, 514 insertions(+) create mode 100644 pixels/behaviours/vitual_tunnel.py diff --git a/pixels/behaviours/vitual_tunnel.py b/pixels/behaviours/vitual_tunnel.py new file mode 100644 index 0000000..6d9d16c --- /dev/null +++ b/pixels/behaviours/vitual_tunnel.py @@ -0,0 +1,514 @@ +""" +This module provides reach task specific operations. +""" + +from __future__ import annotations + +import pickle + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +from vision_in_darkness.session import Outcome + +from pixels import Experiment, PixelsError +from pixels import signal, ioutils +from pixels.behaviours import Behaviour + + +class ActionLabels: + """ + These actions cover all possible trial types. 'Left' and 'right' correspond to the + trial's correct side i.e. which LED was illuminated. This means `incorrect_left` + trials involved reaches to the right hand target when the left LED was on. + + To align trials to more than one action type they can be bitwise OR'd i.e. + `miss_left | miss_right` will match all miss trials. 
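The labels above form a bitmask: every outcome gets its own bit, and compound labels are bitwise ORs of those bits, so a single integer per trial can be tested against any combination. A minimal standalone sketch of that pattern, with hypothetical flag values that only mirror the class body that follows and the `np.bitwise_and` selection used elsewhere in pixels:

    import numpy as np

    # Hypothetical flags mirroring the ActionLabels layout.
    MISS_LIGHT = 1 << 0
    MISS_DARK = 1 << 1
    CORRECT_LIGHT = 1 << 2
    MISS = MISS_LIGHT | MISS_DARK  # compound label: any miss

    # Column 0 of an action_labels array holds one label per trial start.
    actions = np.array([MISS_LIGHT, CORRECT_LIGHT, MISS_DARK, 0])

    # Selecting every miss trial, light or dark, is a single bitwise AND.
    miss_trials = np.where(np.bitwise_and(actions, MISS))[0]
    print(miss_trials)  # [0 2]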
+ """ + miss_light = 1 << 0 + miss_dark = 1 << 1 + correct_light = 1 << 2 + correct_dark = 1 << 3 + punished_light = 1 << 4 + punished_dark = 1 << 5 + #TODO mar 7 continue here + + miss = miss_light | miss_dark + correct = correct_light | correct_dark + punished = punished_light | punished_dark + light = miss_light | correct_light | incorrect_light + dark = miss_dark | correct_dark | incorrect_dark + + # Timepoints determined from motion tracking + clean_left = 1 << 10 # Cued single reach to grasp + clean_right = 1 << 11 + multi_left = 1 << 12 # Cued multiple reaches before reward + multi_right = 1 << 13 + precue_rewarded_left = 1 << 14 # Cued by well-timed spontaneous reach right before + precue_rewarded_right = 1 << 15 + tracking_fail_left = 1 << 16 # Motion tracking failed to get reach trajectory + tracking_fail_right = 1 << 17 + long_reach_duration_left = 1 << 17 + long_reach_duration_right = 1 << 18 + clean = clean_left | clean_right + multi = multi_left | multi_right + precue_rewarded = precue_rewarded_left | precue_rewarded_right + tracking_fail = tracking_fail_left | tracking_fail_right + long_reach_duration = long_reach_duration_left | long_reach_duration_right + + clean_incorrect_left = 1 << 19 # Cued single reach to grasp + clean_incorrect_right = 1 << 20 + multi_incorrect_left = 1 << 21 # Cued multiple reaches before reward + multi_incorrect_right = 1 << 22 + precue_incorrect_left = 1 << 23 # Cued by well-timed spontaneous reach right before + precue_incorrect_right = 1 << 24 + tracking_fail_incorrect_left = 1 << 25 # Motion tracking failed to get reach trajectory + tracking_fail_incorrect_right = 1 << 26 + long_reach_duration_incorrect_left = 1 << 27 + long_reach_duration_incorrect_right = 1 << 28 + clean_incorrect = clean_incorrect_left | clean_incorrect_right + multi_incorrect = multi_incorrect_left | multi_incorrect_right + precue_incorrect = precue_incorrect_left | precue_incorrect_right + tracking_fail_incorrect = tracking_fail_incorrect_left | tracking_fail_incorrect_right + long_reach_duration_incorrect = long_reach_duration_incorrect_left | long_reach_duration_incorrect_right + + +class Events: + led_on = 1 << 0 + led_off = 1 << 1 + + # Timepoints determined from motion tracking + reach_onset = 1 << 2 + slit_in = 1 << 3 + grasp = 1 << 4 + slit_out = 1 << 5 + subsequent_slit_in = 1 << 6 # The SECOND full reach on a clean correct trial only + subsequent_grasp = 1 << 7 + subsequent_slit_out = 1 << 8 + + +# These are used to convert the trial data into Actions and Events +_side_map = { + Targets.LEFT: "left", + Targets.RIGHT: "right", +} + +_action_map = { + Outcomes.MISSED: "miss", + Outcomes.CORRECT: "correct", + Outcomes.INCORRECT: "incorrect", +} + + + +class Reach(Behaviour): + def _preprocess_behaviour(self, rec_num, behavioural_data): + # Correction for sessions where sync channel interfered with LED channel + if behavioural_data["/'ReachLEDs'/'0'"].min() < -2: + behavioural_data["/'ReachLEDs'/'0'"] = behavioural_data["/'ReachLEDs'/'0'"] \ + + 0.5 * behavioural_data["/'NpxlSync_Signal'/'0'"] + + behavioural_data = signal.binarise(behavioural_data) + action_labels = np.zeros((len(behavioural_data), 2), dtype=np.uint64) + + try: + cue_leds = behavioural_data["/'ReachLEDs'/'0'"].values + except KeyError: + # some early recordings still used this key + cue_leds = behavioural_data["/'Back_Sensor'/'0'"].values + + led_onsets = np.where((cue_leds[:-1] == 0) & (cue_leds[1:] == 1))[0] + led_offsets = np.where((cue_leds[:-1] == 1) & (cue_leds[1:] == 0))[0] + 
action_labels[led_onsets, 1] += Events.led_on + action_labels[led_offsets, 1] += Events.led_off + metadata = self.metadata[rec_num] + + # QA: Check that the JSON and TDMS data have the same number of trials + if len(led_onsets) != len(metadata["trials"]): + # If they do not have the same number, perhaps the TDMS was stopped too early + meta_onsets = np.array([t["start"] for t in metadata["trials"]]) * 1000 + meta_onsets = (meta_onsets - meta_onsets[0] + led_onsets[0]).astype(int) + if meta_onsets[-1] > len(cue_leds): + # TDMS stopped too early, continue anyway. + i = -1 + while meta_onsets[i] > len(cue_leds): + metadata["trials"].pop() + i -= 1 + assert len(led_onsets) == len(metadata["trials"]) + else: + # If you have come to debug and see why this error was raised, try: + # led_onsets - meta_onsets[:len(led_onsets)] # This might show the problem + # meta_onsets - led_onsets[:len(meta_onsets)] # This might show the problem + # Then just patch a fix here: + if self.name == "211027_VR49" and rec_num == 1: + del metadata["trials"][52] # Maybe cable fell out of DAQ input? + else: + raise PixelsError( + f"{self.name}: Mantis and Raspberry Pi behavioural " + "data have different no. of trials" + ) + + # QA: Last offset not found in tdms data? + if len(led_offsets) < len(led_onsets): + last_trial = self.metadata[rec_num]['trials'][-1] + if "end" in last_trial: + # Take known offset from metadata + offset = led_onsets[-1] + (last_trial['end'] - last_trial['start']) * 1000 + led_offsets = np.append(led_offsets, int(offset)) + else: + # If not possible, just remove last onset + led_onsets = led_onsets[:-1] + metadata["trials"].pop() + assert len(led_offsets) == len(led_onsets) + + # QA: For some reason, sometimes the final trial metadata doesn't include the + # final led-off even though it is detectable in the TDMS data. + elif len(led_offsets) == len(led_onsets): + # Not sure how to deal with this if led_offsets and led_onsets differ in length + if len(metadata["trials"][-1]) == 1 and "start" in metadata["trials"][-1]: + # Remove it, because we would have to check the video to get all of the + # information about the trial, and it's too complicated. + metadata["trials"].pop() + led_onsets = led_onsets[:-1] + led_offsets = led_offsets[:-1] + + # QA: Check that the cue durations (mostly) match between JSON and TDMS data + # This compares them at 10s of milliseconds resolution + cue_durations_tdms = (led_offsets - led_onsets) / 100 + cue_durations_json = np.array( + [t['end'] - t['start'] for t in metadata['trials']] + ) * 10 + error = sum( + (cue_durations_tdms - cue_durations_json).round() != 0 + ) / len(led_onsets) + if error > 0.05: + raise PixelsError( + f"{self.name}: Mantis and Raspberry Pi behavioural data have mismatching trial data." 
+ ) + + return behavioural_data, action_labels, led_onsets + + def _extract_action_labels(self, rec_num, behavioural_data, plot=False): + behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(rec_num, behavioural_data) + + for i, trial in enumerate(self.metadata[rec_num]["trials"]): + side = _side_map[trial["spout"]] + outcome = trial["outcome"] + if outcome in _action_map: + action = _action_map[trial["outcome"]] + action_labels[led_onsets[i], 0] += getattr(ActionLabels, f"{action}_{side}") + + if plot: + plt.clf() + _, axes = plt.subplots(4, 1, sharex=True, sharey=True) + axes[0].plot(back_sensor_signal) + if "/'Back_Sensor'/'0'" in behavioural_data: + axes[1].plot(behavioural_data["/'Back_Sensor'/'0'"].values) + else: + axes[1].plot(behavioural_data["/'ReachCue_LEDs'/'0'"].values) + axes[2].plot(action_labels[:, 0]) + axes[3].plot(action_labels[:, 1]) + plt.plot(action_labels[:, 1]) + plt.show() + + return action_labels + + def draw_slit_thresholds(self, project: str, force: bool = False): + """ + Draw lines on the slits using EasyROI. If ROIs already exist, skip. + + Parameters + ========== + project : str + The DLC project i.e. name/prefix of the camera. + + force : bool + If true, we will draw new lines even if the output file exists. + + """ + # Only needed for this method + import cv2 + import EasyROI + + output = self.processed / f"slit_thresholds_{project}.pickle" + + if output.exists() and not force: + print(self.name, "- slits drawn already.") + return + + # Let's take the average between the first and last frames of the whole session. + videos = [] + + for recording in self.files: + for v, video in enumerate(recording.get("camera_data", [])): + if project in video.stem: + avi = self.interim / video.with_suffix('.avi') + if not avi.exists(): + meta = recording['camera_meta'][v] + ioutils.tdms_to_video( + self.find_file(video, copy=False), + self.find_file(meta), + avi, + ) + if not avi.exists(): + raise PixelsError(f"Path {avi} should exist but doesn't... discuss.") + videos.append(avi.as_posix()) + + if not videos: + raise PixelsError("No videos were found to draw slits on.") + + first_frame = ioutils.load_video_frame(videos[0], 1) + last_duration = ioutils.get_video_dimensions(videos[-1])[2] + last_frame = ioutils.load_video_frame(videos[-1], last_duration - 1) + + average_frame = np.concatenate( + [first_frame[..., None], last_frame[..., None]], + axis=2, + ).mean(axis=2) + average_frame = np.squeeze(average_frame) / 255 + + # Interactively draw ROI + global _roi_helper + if _roi_helper is None: + # Ugly but we can only have one instance of this + _roi_helper = EasyROI.EasyROI(verbose=False) + lines = _roi_helper.draw_line(average_frame, 2) + cv2.destroyAllWindows() # Needed otherwise EasyROI errors + + # Save a copy of the frame with ROIs to PNG file + png = output.with_suffix(".png") + copy = EasyROI.visualize_line(average_frame, lines, color=(255, 0, 0)) + plt.imsave(png, copy, cmap='gray') + + # Save lines to file + with output.open('wb') as fd: + pickle.dump(lines['roi'], fd) + + def inject_slit_crossings(self): + """ + Take the lines drawn from `draw_slit_thresholds` above, get the reach + coordinates from DLC output, identify the timepoints when successful reaches + crossed the lines, and add the `reach_onset` event to those timepoints in the + action labels. Also identify which trials need clearing up, i.e. those with + multiple reaches or have failed motion tracking, and exclude those. 
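Crossing detection here comes down to the counter-clockwise (ccw) orientation test used in the method body: two segments intersect when each segment's endpoints lie on opposite sides of the other. A small standalone sketch of that predicate, using simple namedtuple points rather than the pandas Series the real code passes in:

    from collections import namedtuple

    Point = namedtuple("Point", ["x", "y"])

    def ccw(a, b, c):
        # True when the turn a -> b -> c is counter-clockwise (y axis up).
        return (c.y - a.y) * (b.x - a.x) > (b.y - a.y) * (c.x - a.x)

    def segments_intersect(p1, p2, p3, p4):
        # p1-p2 and p3-p4 cross iff each pair straddles the other segment.
        # Collinear/degenerate cases are ignored, as in the reference linked
        # in the method body.
        return (ccw(p1, p3, p4) != ccw(p2, p3, p4)
                and ccw(p1, p2, p3) != ccw(p1, p2, p4))

    slit_line = (Point(0, 0), Point(10, 10))   # a drawn threshold line
    reach_step = (Point(0, 10), Point(10, 0))  # one step of a hand trajectory
    print(segments_intersect(*slit_line, *reach_step))  # True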
+ """ + lines = {} + projects = ("LeftCam", "RightCam") + + for project in projects: + line_file = self.processed / f"slit_thresholds_{project}.pickle" + + if not line_file.exists(): + print(self.name, "- Lines not drawn for session.") + return + + with line_file.open("rb") as f: + proj_lines = pickle.load(f) + + lines[project] = { + tt:[pd.Series(p, index=["x", "y"]) for p in points.values()] + for tt, points in proj_lines.items() + } + + action_labels = self.get_action_labels() + event = Events.led_off + + # https://bryceboe.com/2006/10/23/line-segment-intersection-algorithm + def ccw(A, B, C): + return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x) + + for tt, action in enumerate( + (ActionLabels.correct_left, ActionLabels.correct_right), + ): + data = {} + trajectories = {} + + for project in projects: + proj_data = self.align_trials( + action, + event, + "motion_tracking", + duration=6, + dlc_project=project, + ) + proj_traj, = check_scorers(proj_data) + data[project] = proj_traj + trajectories[project] = get_reach_trajectories(proj_traj)[0] + + for rec_num, recording in enumerate(self.files): + actions = action_labels[rec_num][:, 0] + events = action_labels[rec_num][:, 1] + trial_starts = np.where(np.bitwise_and(actions, action))[0] + + for t, start in enumerate(trial_starts): + centre = np.where(np.bitwise_and(events[start:start + 6000], event))[0] + if len(centre) == 0: + raise PixelsError('Action labels probably miscalculated') + centre = start + centre[0] + # centre is index of this rec's grasp + onsets = [] + for project, motion in trajectories.items(): + left_hand = motion["left_hand_median"][t][:0] + right_hand = motion["right_hand_median"][t][:0] + pt1, pt2 = lines[project][tt] + + x_l = left_hand.iloc[-10:]["x"].mean() + x_r = right_hand.iloc[-10:]["x"].mean() + hand = right_hand + if project == "LeftCam": + if x_l > x_r: + hand = left_hand + else: + if x_r > x_l: + hand = left_hand + + segments = zip( + hand.iloc[::-1].iterrows(), + hand.iloc[-2::-1].iterrows(), + ) + + for (end, ptend), (start, ptsta) in segments: + if ( + ccw(pt1, ptsta, ptend) != ccw(pt2, ptsta, ptend) and + ccw(pt1, pt2, ptsta) != ccw(pt1, pt2, ptend) + ): + # These lines intersect + #print("x from ", ptsta.x, " to ", ptend.x) + break + #if ptend.y > 300 or ptsta.y > 300: + # assert 0 + onsets.append(start) + + onset = max(onsets) + onset_timepoint = round(centre + (onset * 1000)) + events[onset_timepoint] |= Events.reach_onset + + output = self.processed / recording['action_labels'] + np.save(output, action_labels[rec_num]) + + +global _roi_helper +_roi_helper = None + + +class VisualOnly(Reach): + def _extract_action_labels(self, behavioural_data, plot=False): + behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(behavioural_data) + + for i, trial in enumerate(self.metadata["trials"]): + label = "naive_" + _side_map[trial["spout"]] + "_" + if trial["cue_duration"] > 125: + label += "long" + else: + label += "short" + action_labels[led_onsets[i], 0] += getattr(ActionLabels, label) + + return action_labels + + +def get_reach_velocities(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Get the velocity curves for the provided reach trajectories. 
+ """ + results = [] + + for df in dfs: + df = df.copy() + deltas = np.square(df.iloc[1:].values - df.iloc[:-1].values) + # Fill the start with a row of zeros - each value is delta in previous 1 ms + deltas = np.append(np.zeros((1, deltas.shape[1])), deltas, axis=0) + deltas = np.sqrt(deltas[:, ::2] + deltas[:, 1::2]) + df = df.drop([c for c in df.columns if "y" in c], axis=1) + df = df.rename({"x": "delta"}, axis='columns') + df.values[:] = deltas + results.append(df) + + return tuple(results) + + +def get_reach_trajectories(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Get the median centre point of the hand coordinates - i.e. for the labels for each + of the four digits and hand centre for both paws. + """ + assert dfs + bodyparts = get_body_parts(dfs[0]) + right_paw = [p for p in bodyparts if p.startswith("right")] + left_paw = [p for p in bodyparts if p.startswith("left")] + + trajectories_l = [] + trajectories_r = [] + + for df in dfs: + per_ses_l = [] + per_ses_r = [] + + # Ugly hack so this function works on single sessions + if "session" not in df.columns.names: + df = pd.concat([df], keys=[0], axis=1) + sessions = [0] + single_session = True + else: + single_session = False + sessions = df.columns.get_level_values("session").unique() + + for s in sessions: + per_trial_l = [] + per_trial_r = [] + + trials = df[s].columns.get_level_values("trial").unique() + for t in trials: + tdf = df[s][t] + left = pd.concat((tdf[p] for p in left_paw), keys=left_paw) + right = pd.concat((tdf[p] for p in right_paw), keys=right_paw) + per_trial_l.append(left.groupby(level=1).median()) + per_trial_r.append(right.groupby(level=1).median()) + + per_ses_l.append(pd.concat(per_trial_l, axis=1, keys=trials)) + per_ses_r.append(pd.concat(per_trial_r, axis=1, keys=trials)) + + trajectories_l.append(pd.concat(per_ses_l, axis=1, keys=sessions)) + trajectories_r.append(pd.concat(per_ses_r, axis=1, keys=sessions)) + + if single_session: + return tuple( + pd.concat( + [trajectories_l[i][0], trajectories_r[i][0]], + axis=1, + keys=["left_hand_median", "right_hand_median"], + ) + for i in range(len(dfs)) + ) + return tuple( + pd.concat( + [trajectories_l[i], trajectories_r[i]], + axis=1, + keys=["left_hand_median", "right_hand_median"], + ) + for i in range(len(dfs)) + ) + + +def check_scorers(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: + """ + Checks that the scorers are identical for all data in the dataframes. These are the + dataframes as returned from Exp.align_trials for motion_tracking data. + + It returns the dataframes with the scorer index level removed. + """ + scorers = set( + s + for df in dfs + for s in df.columns.get_level_values("scorer").unique() + ) + + assert len(scorers) == 1, scorers + return tuple(df.droplevel("scorer", axis=1) for df in dfs) + + +def get_body_parts(df: pd.DataFrame) -> list[str]: + """ + Get the list of body part labels. 
+ """ + return df.columns.get_level_values("bodyparts").unique() From 41075542d0b74bd534360a1e92381ee34e9bfd78 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 28 May 2024 19:04:12 +0100 Subject: [PATCH 050/658] make sure only finds folders with given mouse id --- pixels/ioutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index f5d9578..fbc3111 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -348,7 +348,7 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): raw_dir = data_dir / 'raw' for mouse in mouse_ids: - mouse_sessions = list(raw_dir.glob(f'*{mouse}*')) + mouse_sessions = list(raw_dir.glob(f'*{mouse}')) if not mouse_sessions: print(f'Found no sessions for: {mouse}') From e15b0e8dc1cd35b069afa76925f8440019ed4cc5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 28 May 2024 19:04:56 +0100 Subject: [PATCH 051/658] add comments to make it more readable --- pixels/signal.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/pixels/signal.py b/pixels/signal.py index 61254de..0d8cba1 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -101,7 +101,7 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): current += chunk_size print(f" {100 * current / cols:.1f}%", end="\r") - return np.concatenate(new_data, axis=1) #.astype(np.int16) + return np.concatenate(new_data, axis=1).squeeze()#.astype(np.int16) def binarise(data): @@ -139,8 +139,8 @@ def find_sync_lag(array1, array2, plot=False): Parameters ---------- array1 : array, Series or similar - The first array. A positive result indicates that this array has leading data - not present in the second array. e.g. if lag == 5 then array2 starts on the 5th + The first array. A positive result indicates that THIS ARRAY HAS LEADING DATA + NOT PRESENT IN THE SECOND ARRAY. e.g. if lag == 5 then array2 starts on the 5th index of array1. 
array2 : array, Series or similar @@ -165,26 +165,32 @@ def find_sync_lag(array1, array2, plot=False): array1 = array1.squeeze() array2 = array2.squeeze() - sync_p = [] + sync_pos = [] for i in range(length): + # finds how many values are the same in array1 as in array2 till given length matches = np.count_nonzero(array1[i:i + length] == array2[:length]) - sync_p.append(100 * matches / length) - match_p = max(sync_p) - lag_p = sync_p.index(match_p) - - sync_n = [] + # append the percentage of match given length + sync_pos.append(100 * matches / length) + # take the highest percentage during checks as the match + match_pos = max(sync_pos) + # find index where lag started in array1 + lag_pos = sync_pos.index(match_pos) + + sync_neg = [] for i in range(length): + # finds how many values are the same in array2 as in array1 till given length matches = np.count_nonzero(array2[i:i + length] == array1[:length]) - sync_n.append(100 * matches / length) - match_n = max(sync_n) - lag_n = sync_n.index(match_n) - - if match_p > match_n: - lag = lag_p - match = match_p + # append the percentage of match given length + sync_neg.append(100 * matches / length) + match_neg = max(sync_neg) + lag_neg = sync_neg.index(match_neg) + + if match_pos > match_neg: + lag = lag_pos + match = match_pos else: - lag = - lag_n - match = match_n + lag = - lag_neg + match = match_neg if plot: plot = Path(plot) From 0e6b7dc3df7c3c20d6e783e3eb9eccd044c24f9d Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 28 May 2024 19:05:31 +0100 Subject: [PATCH 052/658] clean up vr action labels --- pixels/behaviours/vitual_tunnel.py | 59 +++++++++++------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/pixels/behaviours/vitual_tunnel.py b/pixels/behaviours/vitual_tunnel.py index 6d9d16c..a7a9523 100644 --- a/pixels/behaviours/vitual_tunnel.py +++ b/pixels/behaviours/vitual_tunnel.py @@ -26,57 +26,42 @@ class ActionLabels: To align trials to more than one action type they can be bitwise OR'd i.e. `miss_left | miss_right` will match all miss trials. 
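The lag search commented in find_sync_lag above can be sanity-checked on synthetic data: build one binary train, prepend a few extra samples to a copy, and the offset with the highest match count should equal the number of prepended samples. A standalone sketch of that counting idea under those assumptions; it is not the library function itself and only scans in one direction:

    import numpy as np

    rng = np.random.default_rng(0)
    sync = rng.integers(0, 2, size=200)                      # shared sync train
    array2 = sync
    array1 = np.concatenate([np.ones(5, dtype=int), sync])   # 5 leading samples

    length = len(array2) // 2
    scores = [
        np.count_nonzero(array1[i:i + length] == array2[:length])
        for i in range(length)
    ]
    lag = int(np.argmax(scores))
    print(lag)  # 5, i.e. array2 starts on the 5th index of array1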
""" + # triggered events miss_light = 1 << 0 miss_dark = 1 << 1 correct_light = 1 << 2 correct_dark = 1 << 3 punished_light = 1 << 4 punished_dark = 1 << 5 - #TODO mar 7 continue here + # given reward + default_light = 1 << 6 + default_dark = 1 << 7 + auto_light = 1 << 8 + auto_dark = 1 << 9 + reinf_light = 1 << 10 + reinf_dark = 1 << 11 + + # combos miss = miss_light | miss_dark correct = correct_light | correct_dark punished = punished_light | punished_dark - light = miss_light | correct_light | incorrect_light - dark = miss_dark | correct_dark | incorrect_dark - # Timepoints determined from motion tracking - clean_left = 1 << 10 # Cued single reach to grasp - clean_right = 1 << 11 - multi_left = 1 << 12 # Cued multiple reaches before reward - multi_right = 1 << 13 - precue_rewarded_left = 1 << 14 # Cued by well-timed spontaneous reach right before - precue_rewarded_right = 1 << 15 - tracking_fail_left = 1 << 16 # Motion tracking failed to get reach trajectory - tracking_fail_right = 1 << 17 - long_reach_duration_left = 1 << 17 - long_reach_duration_right = 1 << 18 - clean = clean_left | clean_right - multi = multi_left | multi_right - precue_rewarded = precue_rewarded_left | precue_rewarded_right - tracking_fail = tracking_fail_left | tracking_fail_right - long_reach_duration = long_reach_duration_left | long_reach_duration_right - - clean_incorrect_left = 1 << 19 # Cued single reach to grasp - clean_incorrect_right = 1 << 20 - multi_incorrect_left = 1 << 21 # Cued multiple reaches before reward - multi_incorrect_right = 1 << 22 - precue_incorrect_left = 1 << 23 # Cued by well-timed spontaneous reach right before - precue_incorrect_right = 1 << 24 - tracking_fail_incorrect_left = 1 << 25 # Motion tracking failed to get reach trajectory - tracking_fail_incorrect_right = 1 << 26 - long_reach_duration_incorrect_left = 1 << 27 - long_reach_duration_incorrect_right = 1 << 28 - clean_incorrect = clean_incorrect_left | clean_incorrect_right - multi_incorrect = multi_incorrect_left | multi_incorrect_right - precue_incorrect = precue_incorrect_left | precue_incorrect_right - tracking_fail_incorrect = tracking_fail_incorrect_left | tracking_fail_incorrect_right - long_reach_duration_incorrect = long_reach_duration_incorrect_left | long_reach_duration_incorrect_right + # trial type combos + light = miss_light | correct_light | punished_light | default_light | + auto_light | reinf_light + dark = miss_dark | correct_dark | punished_dark | default_dark | + auto_dark | reinf_dark + rewarded_light = correct_light | default_light | auto_light | reinf_light + rewarded_dark = correct_dark | default_dark | auto_dark | reinf_dark + given_light = default_light | auto_light | reinf_light + given_dark = default_dark | auto_dark | reinf_dark + #TODO mar 7 continue here class Events: - led_on = 1 << 0 - led_off = 1 << 1 + gray_on = 1 << 0 + tunnel_on = 1 << 1 # Timepoints determined from motion tracking reach_onset = 1 << 2 From 66e6598d5c58c51e3c0944567360b1665b4b7bcc Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Jun 2024 19:04:58 +0100 Subject: [PATCH 053/658] give vr proper name; extract some action labels --- pixels/behaviours/virtual_reality.py | 233 +++++++++++++ pixels/behaviours/vitual_tunnel.py | 499 --------------------------- 2 files changed, 233 insertions(+), 499 deletions(-) create mode 100644 pixels/behaviours/virtual_reality.py delete mode 100644 pixels/behaviours/vitual_tunnel.py diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py new file mode 
100644 index 0000000..c714799 --- /dev/null +++ b/pixels/behaviours/virtual_reality.py @@ -0,0 +1,233 @@ +""" +This module provides reach task specific operations. +""" + +from __future__ import annotations + +import pickle + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +from vision_in_darkness.session import Outcome, World + +from pixels import Experiment, PixelsError +from pixels import signal, ioutils +from pixels.behaviours import Behaviour + +from common_utils import file_utils + +SAMPLE_RATE = 1000 + +class ActionLabels: + """ + These actions cover all possible trial types. + + To align trials to more than one action type they can be bitwise OR'd i.e. + `miss_light | miss_dark` will match all miss trials. + + Actions can NOT be added on top of each other, they should be mutually + exclusive. + """ + # TODO jun 7 2024 does the name "action" make sense? + + # triggered vr trials + miss_light = 1 << 0 + miss_dark = 1 << 1 + triggered_light = 1 << 2 + triggered_dark = 1 << 3 + punished_light = 1 << 4 + punished_dark = 1 << 5 + + # given reward + default_light = 1 << 6 + auto_light = 1 << 7 + auto_dark = 1 << 8 + reinf_light = 1 << 9 + reinf_dark = 1 << 10 + + # combos + miss = miss_light | miss_dark + triggered = triggered_light | triggered_dark + punished = punished_light | punished_dark + + # trial type combos + light = miss_light | triggered_light | punished_light | default_light | + auto_light | reinf_light + dark = miss_dark | triggered_dark | punished_dark | auto_dark | reinf_dark + rewarded_light = triggered_light | default_light | auto_light | reinf_light + rewarded_dark = triggered_dark | auto_dark | reinf_dark + given_light = default_light | auto_light | reinf_light + given_dark = auto_dark | reinf_dark + + +class Events: + """ + Defines events that could happen during vr sessions. + + Events can be added on top of each other. 
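Since each event occupies its own bit, several events can coexist at a single sample: they are ORed into column 1 of the labels array and recovered independently with a bitwise AND. A minimal standalone sketch of that pattern; the flag values are illustrative only, and the real class assigns its own bits:

    import numpy as np

    GRAY_ON = 1 << 0
    VALVE_OPEN = 1 << 17
    LICKED = 1 << 19

    events = np.zeros(1000, dtype=np.int32)  # one entry per sample
    events[120] |= GRAY_ON                    # gray screen comes on
    events[120] |= LICKED                     # a lick lands on the same sample

    # Each event is still recoverable on its own.
    print(np.where(events & GRAY_ON)[0])     # [120]
    print(np.where(events & LICKED)[0])      # [120]
    print(np.where(events & VALVE_OPEN)[0])  # []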
+ """ + # vr events + gray_on = 1 << 0 + gray_off = 1 << 1 + tunnel_on = 1 << 2 + tunnel_off = 1 << 3 + dark_on = 1 << 4 + dark_off = 1 << 5 + punish_on = 1 << 6 + punish_off = 1 << 7 + session_end = 1 << 8 + + # positional events + black = 1 << 9 # 0 - 60 cm + wall = 1 << 10 # in between landmarks + landmark1 = 1 << 11 # 110 - 130 cm + landmark2 = 1 << 12 # 190 - 210 cm + landmark3 = 1 << 13 # 270 - 290 cm + landmark4 = 1 << 14 # 350 - 370 cm + landmark5 = 1 << 15 # 430 - 450 cm + reward_zone = 1 << 16 # 460 - 495 cm + + # sensors + valve_open = 1 << 17 + valve_closed = 1 << 18 + licked = 1 << 19 + #run_start = 1 << 12 + #run_stop = 1 << 13 + + +# convert the trial data into Actions and Events +_action_map = { + Outcome.ABORTED_DARK: "miss_dark", + Outcome.ABORTED_LIGHT: "miss_light", + Outcome.NONE: "miss", + Outcome.TRIGGERED: "triggered", + Outcome.AUTO_LIGHT: "auto_light", + Outcome.DEFAULT: "default_light", + Outcome.REINF_LIGHT: "reinf_light", + Outcome.AUTO_DARK: "auto_dark", + Outcome.REINF_DARK: "reinf_dark", +} + + + +class VR(Behaviour): + + def _extract_action_labels(self, vr_data): + + # TEMPORARY load synced vr data + cache_dir = "/home/amz/interim/behaviour_cache/temp100/trials/" + vr_data = file_utils.load_pickle( + cache_dir + "20231130_az_WDAN07_upsampled.pickle" + #self.cache_dir + 'trials/' + self.name + '_upsampled.pickle' + ) + + # create action label array for actions & events + action_label = np.zeros((vr_data.shape[0], 2), dtype=np.int32) + + print(">> Mapping vr event times...") + + # get gray_on times, i.e., trial starts + gray_idx = vr_data.world_index[vr_data.world_index == World.GRAY].index + # grays + grays = np.where(gray_idx.diff() != 1)[0] + + # first frame of gray + gray_on = gray_idx[grays].values + action_labels[gray_on, 1] += Events.gray_on + + # last frame of gray + gray_off = np.append(gray_idx[grays[1:] - 1], gray_idx[-1]) + action_labels[gray_off, 1] += Events.gray_off + + # get tunnel_on times, i.e., tunnel starts + tunnel_idx = vr_data.world_index[vr_data.world_index == World.TUNNEL].index + # tunnels + tunnels = np.where(tunnel_idx.diff() != 1)[0] + + # tunnel starts is first frame after gray + tunnel_on = gray_off + 1 + action_labels[tunnel_on, 1] += Events.tunnel_on + + # get usual tunnel_off times + tunnel_off = np.append(tunnel_idx[tunnels[1:] - 1], tunnel_idx[-1]) + action_labels[tunnel_off, 1] += Events.tunnel_off + + # get dark on + # define in dark + in_dark = (vr_data.world_index == World.DARK_5)\ + |(vr_data.world_index == World.DARK_2_5)\ + |(vr_data.world_index == World.DARK_FULL) + # get index in dark + in_dark_idx = vr_data[in_dark].index + # number of dark_on does not match with number of dark trials caused by + # triggering punishment before dark + # darks + darks = np.where(in_dark_idx.diff() != 1)[0] + + # first frame of dark + dark_on = in_dark_idx[darks].values + action_labels[dark_on, 1] += Events.dark_on + + # last frame of dark + dark_off = np.append(in_dark_idx[darks[1:] - 1], in_dark_idx[-1]) + action_labels[dark_off, 1] += Events.dark_off + + # map licks + licked_idx = vr_data.lick_count[vr_data.lick_count == 1].index.values + action_labels[licked_idx, 1] += Events.licked + + + print(">> Mapping vr action times...") + # map reward zone + in_zone = (vr_data.position_in_tunnel >= self.reward_zone_start)\ + & (vr_data.position_in_tunnel <= self.reward_zone_end) + # get in reward zone index + in_zone_idx = vr_data[in_zone].index + # get reward type while in reward zone + reward_type = vr_data.reward_type.loc[in_zone_idx] + 
+ # after reward zone before trial resets + pass_zone = (vr_data.position_in_tunnel > self.reward_zone_end)\ + & (vr_data.position_in_tunnel <= self.tunnel_length) + # get passed reward zone index + pass_zone_idx = vr_data[pass_zone].index + end_reward_type = vr_data.reward_type.loc[pass_zone_idx] + + # missed light trials + miss_light = end_reward_type[end_reward_type == + Outcome.ABORTED_LIGHT].index.values + # missed dark trials + miss_dark = end_reward_type[end_reward_type == + Outcome.ABORTED_DARK].index.values + + # TODO jun 7 2024 action labels not mapped + + + return action_labels + + + def _old_extract_action_labels(self, vr_data, action_labels, plot=False): + + for i, trial in enumerate(self.metadata[rec_num]["trials"]): + side = _side_map[trial["spout"]] + outcome = trial["outcome"] + if outcome in _action_map: + action = _action_map[trial["outcome"]] + action_labels[led_onsets[i], 0] += getattr(ActionLabels, f"{action}_{side}") + + if plot: + plt.clf() + _, axes = plt.subplots(4, 1, sharex=True, sharey=True) + axes[0].plot(back_sensor_signal) + if "/'Back_Sensor'/'0'" in behavioural_data: + axes[1].plot(behavioural_data["/'Back_Sensor'/'0'"].values) + else: + axes[1].plot(behavioural_data["/'ReachCue_LEDs'/'0'"].values) + axes[2].plot(action_labels[:, 0]) + axes[3].plot(action_labels[:, 1]) + plt.plot(action_labels[:, 1]) + plt.show() + + return action_labels diff --git a/pixels/behaviours/vitual_tunnel.py b/pixels/behaviours/vitual_tunnel.py deleted file mode 100644 index a7a9523..0000000 --- a/pixels/behaviours/vitual_tunnel.py +++ /dev/null @@ -1,499 +0,0 @@ -""" -This module provides reach task specific operations. -""" - -from __future__ import annotations - -import pickle - -import numpy as np -import matplotlib.pyplot as plt -import pandas as pd - -from vision_in_darkness.session import Outcome - -from pixels import Experiment, PixelsError -from pixels import signal, ioutils -from pixels.behaviours import Behaviour - - -class ActionLabels: - """ - These actions cover all possible trial types. 'Left' and 'right' correspond to the - trial's correct side i.e. which LED was illuminated. This means `incorrect_left` - trials involved reaches to the right hand target when the left LED was on. - - To align trials to more than one action type they can be bitwise OR'd i.e. - `miss_left | miss_right` will match all miss trials. 
- """ - # triggered events - miss_light = 1 << 0 - miss_dark = 1 << 1 - correct_light = 1 << 2 - correct_dark = 1 << 3 - punished_light = 1 << 4 - punished_dark = 1 << 5 - - # given reward - default_light = 1 << 6 - default_dark = 1 << 7 - auto_light = 1 << 8 - auto_dark = 1 << 9 - reinf_light = 1 << 10 - reinf_dark = 1 << 11 - - # combos - miss = miss_light | miss_dark - correct = correct_light | correct_dark - punished = punished_light | punished_dark - - # trial type combos - light = miss_light | correct_light | punished_light | default_light | - auto_light | reinf_light - dark = miss_dark | correct_dark | punished_dark | default_dark | - auto_dark | reinf_dark - rewarded_light = correct_light | default_light | auto_light | reinf_light - rewarded_dark = correct_dark | default_dark | auto_dark | reinf_dark - given_light = default_light | auto_light | reinf_light - given_dark = default_dark | auto_dark | reinf_dark - - #TODO mar 7 continue here - -class Events: - gray_on = 1 << 0 - tunnel_on = 1 << 1 - - # Timepoints determined from motion tracking - reach_onset = 1 << 2 - slit_in = 1 << 3 - grasp = 1 << 4 - slit_out = 1 << 5 - subsequent_slit_in = 1 << 6 # The SECOND full reach on a clean correct trial only - subsequent_grasp = 1 << 7 - subsequent_slit_out = 1 << 8 - - -# These are used to convert the trial data into Actions and Events -_side_map = { - Targets.LEFT: "left", - Targets.RIGHT: "right", -} - -_action_map = { - Outcomes.MISSED: "miss", - Outcomes.CORRECT: "correct", - Outcomes.INCORRECT: "incorrect", -} - - - -class Reach(Behaviour): - def _preprocess_behaviour(self, rec_num, behavioural_data): - # Correction for sessions where sync channel interfered with LED channel - if behavioural_data["/'ReachLEDs'/'0'"].min() < -2: - behavioural_data["/'ReachLEDs'/'0'"] = behavioural_data["/'ReachLEDs'/'0'"] \ - + 0.5 * behavioural_data["/'NpxlSync_Signal'/'0'"] - - behavioural_data = signal.binarise(behavioural_data) - action_labels = np.zeros((len(behavioural_data), 2), dtype=np.uint64) - - try: - cue_leds = behavioural_data["/'ReachLEDs'/'0'"].values - except KeyError: - # some early recordings still used this key - cue_leds = behavioural_data["/'Back_Sensor'/'0'"].values - - led_onsets = np.where((cue_leds[:-1] == 0) & (cue_leds[1:] == 1))[0] - led_offsets = np.where((cue_leds[:-1] == 1) & (cue_leds[1:] == 0))[0] - action_labels[led_onsets, 1] += Events.led_on - action_labels[led_offsets, 1] += Events.led_off - metadata = self.metadata[rec_num] - - # QA: Check that the JSON and TDMS data have the same number of trials - if len(led_onsets) != len(metadata["trials"]): - # If they do not have the same number, perhaps the TDMS was stopped too early - meta_onsets = np.array([t["start"] for t in metadata["trials"]]) * 1000 - meta_onsets = (meta_onsets - meta_onsets[0] + led_onsets[0]).astype(int) - if meta_onsets[-1] > len(cue_leds): - # TDMS stopped too early, continue anyway. - i = -1 - while meta_onsets[i] > len(cue_leds): - metadata["trials"].pop() - i -= 1 - assert len(led_onsets) == len(metadata["trials"]) - else: - # If you have come to debug and see why this error was raised, try: - # led_onsets - meta_onsets[:len(led_onsets)] # This might show the problem - # meta_onsets - led_onsets[:len(meta_onsets)] # This might show the problem - # Then just patch a fix here: - if self.name == "211027_VR49" and rec_num == 1: - del metadata["trials"][52] # Maybe cable fell out of DAQ input? 
- else: - raise PixelsError( - f"{self.name}: Mantis and Raspberry Pi behavioural " - "data have different no. of trials" - ) - - # QA: Last offset not found in tdms data? - if len(led_offsets) < len(led_onsets): - last_trial = self.metadata[rec_num]['trials'][-1] - if "end" in last_trial: - # Take known offset from metadata - offset = led_onsets[-1] + (last_trial['end'] - last_trial['start']) * 1000 - led_offsets = np.append(led_offsets, int(offset)) - else: - # If not possible, just remove last onset - led_onsets = led_onsets[:-1] - metadata["trials"].pop() - assert len(led_offsets) == len(led_onsets) - - # QA: For some reason, sometimes the final trial metadata doesn't include the - # final led-off even though it is detectable in the TDMS data. - elif len(led_offsets) == len(led_onsets): - # Not sure how to deal with this if led_offsets and led_onsets differ in length - if len(metadata["trials"][-1]) == 1 and "start" in metadata["trials"][-1]: - # Remove it, because we would have to check the video to get all of the - # information about the trial, and it's too complicated. - metadata["trials"].pop() - led_onsets = led_onsets[:-1] - led_offsets = led_offsets[:-1] - - # QA: Check that the cue durations (mostly) match between JSON and TDMS data - # This compares them at 10s of milliseconds resolution - cue_durations_tdms = (led_offsets - led_onsets) / 100 - cue_durations_json = np.array( - [t['end'] - t['start'] for t in metadata['trials']] - ) * 10 - error = sum( - (cue_durations_tdms - cue_durations_json).round() != 0 - ) / len(led_onsets) - if error > 0.05: - raise PixelsError( - f"{self.name}: Mantis and Raspberry Pi behavioural data have mismatching trial data." - ) - - return behavioural_data, action_labels, led_onsets - - def _extract_action_labels(self, rec_num, behavioural_data, plot=False): - behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(rec_num, behavioural_data) - - for i, trial in enumerate(self.metadata[rec_num]["trials"]): - side = _side_map[trial["spout"]] - outcome = trial["outcome"] - if outcome in _action_map: - action = _action_map[trial["outcome"]] - action_labels[led_onsets[i], 0] += getattr(ActionLabels, f"{action}_{side}") - - if plot: - plt.clf() - _, axes = plt.subplots(4, 1, sharex=True, sharey=True) - axes[0].plot(back_sensor_signal) - if "/'Back_Sensor'/'0'" in behavioural_data: - axes[1].plot(behavioural_data["/'Back_Sensor'/'0'"].values) - else: - axes[1].plot(behavioural_data["/'ReachCue_LEDs'/'0'"].values) - axes[2].plot(action_labels[:, 0]) - axes[3].plot(action_labels[:, 1]) - plt.plot(action_labels[:, 1]) - plt.show() - - return action_labels - - def draw_slit_thresholds(self, project: str, force: bool = False): - """ - Draw lines on the slits using EasyROI. If ROIs already exist, skip. - - Parameters - ========== - project : str - The DLC project i.e. name/prefix of the camera. - - force : bool - If true, we will draw new lines even if the output file exists. - - """ - # Only needed for this method - import cv2 - import EasyROI - - output = self.processed / f"slit_thresholds_{project}.pickle" - - if output.exists() and not force: - print(self.name, "- slits drawn already.") - return - - # Let's take the average between the first and last frames of the whole session. 
- videos = [] - - for recording in self.files: - for v, video in enumerate(recording.get("camera_data", [])): - if project in video.stem: - avi = self.interim / video.with_suffix('.avi') - if not avi.exists(): - meta = recording['camera_meta'][v] - ioutils.tdms_to_video( - self.find_file(video, copy=False), - self.find_file(meta), - avi, - ) - if not avi.exists(): - raise PixelsError(f"Path {avi} should exist but doesn't... discuss.") - videos.append(avi.as_posix()) - - if not videos: - raise PixelsError("No videos were found to draw slits on.") - - first_frame = ioutils.load_video_frame(videos[0], 1) - last_duration = ioutils.get_video_dimensions(videos[-1])[2] - last_frame = ioutils.load_video_frame(videos[-1], last_duration - 1) - - average_frame = np.concatenate( - [first_frame[..., None], last_frame[..., None]], - axis=2, - ).mean(axis=2) - average_frame = np.squeeze(average_frame) / 255 - - # Interactively draw ROI - global _roi_helper - if _roi_helper is None: - # Ugly but we can only have one instance of this - _roi_helper = EasyROI.EasyROI(verbose=False) - lines = _roi_helper.draw_line(average_frame, 2) - cv2.destroyAllWindows() # Needed otherwise EasyROI errors - - # Save a copy of the frame with ROIs to PNG file - png = output.with_suffix(".png") - copy = EasyROI.visualize_line(average_frame, lines, color=(255, 0, 0)) - plt.imsave(png, copy, cmap='gray') - - # Save lines to file - with output.open('wb') as fd: - pickle.dump(lines['roi'], fd) - - def inject_slit_crossings(self): - """ - Take the lines drawn from `draw_slit_thresholds` above, get the reach - coordinates from DLC output, identify the timepoints when successful reaches - crossed the lines, and add the `reach_onset` event to those timepoints in the - action labels. Also identify which trials need clearing up, i.e. those with - multiple reaches or have failed motion tracking, and exclude those. 
- """ - lines = {} - projects = ("LeftCam", "RightCam") - - for project in projects: - line_file = self.processed / f"slit_thresholds_{project}.pickle" - - if not line_file.exists(): - print(self.name, "- Lines not drawn for session.") - return - - with line_file.open("rb") as f: - proj_lines = pickle.load(f) - - lines[project] = { - tt:[pd.Series(p, index=["x", "y"]) for p in points.values()] - for tt, points in proj_lines.items() - } - - action_labels = self.get_action_labels() - event = Events.led_off - - # https://bryceboe.com/2006/10/23/line-segment-intersection-algorithm - def ccw(A, B, C): - return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x) - - for tt, action in enumerate( - (ActionLabels.correct_left, ActionLabels.correct_right), - ): - data = {} - trajectories = {} - - for project in projects: - proj_data = self.align_trials( - action, - event, - "motion_tracking", - duration=6, - dlc_project=project, - ) - proj_traj, = check_scorers(proj_data) - data[project] = proj_traj - trajectories[project] = get_reach_trajectories(proj_traj)[0] - - for rec_num, recording in enumerate(self.files): - actions = action_labels[rec_num][:, 0] - events = action_labels[rec_num][:, 1] - trial_starts = np.where(np.bitwise_and(actions, action))[0] - - for t, start in enumerate(trial_starts): - centre = np.where(np.bitwise_and(events[start:start + 6000], event))[0] - if len(centre) == 0: - raise PixelsError('Action labels probably miscalculated') - centre = start + centre[0] - # centre is index of this rec's grasp - onsets = [] - for project, motion in trajectories.items(): - left_hand = motion["left_hand_median"][t][:0] - right_hand = motion["right_hand_median"][t][:0] - pt1, pt2 = lines[project][tt] - - x_l = left_hand.iloc[-10:]["x"].mean() - x_r = right_hand.iloc[-10:]["x"].mean() - hand = right_hand - if project == "LeftCam": - if x_l > x_r: - hand = left_hand - else: - if x_r > x_l: - hand = left_hand - - segments = zip( - hand.iloc[::-1].iterrows(), - hand.iloc[-2::-1].iterrows(), - ) - - for (end, ptend), (start, ptsta) in segments: - if ( - ccw(pt1, ptsta, ptend) != ccw(pt2, ptsta, ptend) and - ccw(pt1, pt2, ptsta) != ccw(pt1, pt2, ptend) - ): - # These lines intersect - #print("x from ", ptsta.x, " to ", ptend.x) - break - #if ptend.y > 300 or ptsta.y > 300: - # assert 0 - onsets.append(start) - - onset = max(onsets) - onset_timepoint = round(centre + (onset * 1000)) - events[onset_timepoint] |= Events.reach_onset - - output = self.processed / recording['action_labels'] - np.save(output, action_labels[rec_num]) - - -global _roi_helper -_roi_helper = None - - -class VisualOnly(Reach): - def _extract_action_labels(self, behavioural_data, plot=False): - behavioural_data, action_labels, led_onsets = self._preprocess_behaviour(behavioural_data) - - for i, trial in enumerate(self.metadata["trials"]): - label = "naive_" + _side_map[trial["spout"]] + "_" - if trial["cue_duration"] > 125: - label += "long" - else: - label += "short" - action_labels[led_onsets[i], 0] += getattr(ActionLabels, label) - - return action_labels - - -def get_reach_velocities(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: - """ - Get the velocity curves for the provided reach trajectories. 
- """ - results = [] - - for df in dfs: - df = df.copy() - deltas = np.square(df.iloc[1:].values - df.iloc[:-1].values) - # Fill the start with a row of zeros - each value is delta in previous 1 ms - deltas = np.append(np.zeros((1, deltas.shape[1])), deltas, axis=0) - deltas = np.sqrt(deltas[:, ::2] + deltas[:, 1::2]) - df = df.drop([c for c in df.columns if "y" in c], axis=1) - df = df.rename({"x": "delta"}, axis='columns') - df.values[:] = deltas - results.append(df) - - return tuple(results) - - -def get_reach_trajectories(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: - """ - Get the median centre point of the hand coordinates - i.e. for the labels for each - of the four digits and hand centre for both paws. - """ - assert dfs - bodyparts = get_body_parts(dfs[0]) - right_paw = [p for p in bodyparts if p.startswith("right")] - left_paw = [p for p in bodyparts if p.startswith("left")] - - trajectories_l = [] - trajectories_r = [] - - for df in dfs: - per_ses_l = [] - per_ses_r = [] - - # Ugly hack so this function works on single sessions - if "session" not in df.columns.names: - df = pd.concat([df], keys=[0], axis=1) - sessions = [0] - single_session = True - else: - single_session = False - sessions = df.columns.get_level_values("session").unique() - - for s in sessions: - per_trial_l = [] - per_trial_r = [] - - trials = df[s].columns.get_level_values("trial").unique() - for t in trials: - tdf = df[s][t] - left = pd.concat((tdf[p] for p in left_paw), keys=left_paw) - right = pd.concat((tdf[p] for p in right_paw), keys=right_paw) - per_trial_l.append(left.groupby(level=1).median()) - per_trial_r.append(right.groupby(level=1).median()) - - per_ses_l.append(pd.concat(per_trial_l, axis=1, keys=trials)) - per_ses_r.append(pd.concat(per_trial_r, axis=1, keys=trials)) - - trajectories_l.append(pd.concat(per_ses_l, axis=1, keys=sessions)) - trajectories_r.append(pd.concat(per_ses_r, axis=1, keys=sessions)) - - if single_session: - return tuple( - pd.concat( - [trajectories_l[i][0], trajectories_r[i][0]], - axis=1, - keys=["left_hand_median", "right_hand_median"], - ) - for i in range(len(dfs)) - ) - return tuple( - pd.concat( - [trajectories_l[i], trajectories_r[i]], - axis=1, - keys=["left_hand_median", "right_hand_median"], - ) - for i in range(len(dfs)) - ) - - -def check_scorers(*dfs: pd.DataFrame) -> tuple[pd.DataFrame]: - """ - Checks that the scorers are identical for all data in the dataframes. These are the - dataframes as returned from Exp.align_trials for motion_tracking data. - - It returns the dataframes with the scorer index level removed. - """ - scorers = set( - s - for df in dfs - for s in df.columns.get_level_values("scorer").unique() - ) - - assert len(scorers) == 1, scorers - return tuple(df.droplevel("scorer", axis=1) for df in dfs) - - -def get_body_parts(df: pd.DataFrame) -> list[str]: - """ - Get the list of body part labels. 
- """ - return df.columns.get_level_values("bodyparts").unique() From 3821c055be081a88271e88f0470a558af1d57953 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 11 Jun 2024 12:40:02 +0100 Subject: [PATCH 054/658] add vr data; make sure pd hdf funcs do not throw warning --- pixels/ioutils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index fbc3111..533a165 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -119,6 +119,9 @@ def get_data_files(data_dir, session_name): ) recording['CatGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") recording['CatGT_ap_meta'] = str(recording['spike_meta']).replace("t0", "tcat") + recording['vr'] = recording['spike_data'].with_name( + f'{session_name}_vr_synched.pickle' + ) files.append(recording) @@ -291,7 +294,10 @@ def read_hdf5(path): pandas.DataFrame : The dataframe stored within the hdf5 file under the name 'df'. """ - df = pd.read_hdf(path, 'df') + df = pd.read_hdf( + path_or_buf=path, + key='df', + ) return df @@ -308,7 +314,11 @@ def write_hdf5(path, df): Dataframe to save to h5. """ - df.to_hdf(path, 'df', mode='w') + df.to_hdf( + path_or_buf=path, + key='df', + mode='w', + ) print('HDF5 saved to ', path) From 0af5e5a0edddc780f1382a1efde024d0e58ab8bd Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Jun 2024 19:42:33 +0100 Subject: [PATCH 055/658] define conditions and based on which to select events for alignment --- pixels/behaviours/virtual_reality.py | 298 ++++++++++++++++++++------- 1 file changed, 224 insertions(+), 74 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index c714799..d298808 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -2,6 +2,11 @@ This module provides reach task specific operations. """ +# NOTE: for event alignment, we align to the first timepoint when the event +# starts, and the last timepoint before the event ends, i.e., think of an event +# as a train of 0s and 1s, we align to the first 1s and the last 1s of a given +# event. + from __future__ import annotations import pickle @@ -10,7 +15,7 @@ import matplotlib.pyplot as plt import pandas as pd -from vision_in_darkness.session import Outcome, World +from vision_in_darkness.session import Outcome, World, Trial_Type from pixels import Experiment, PixelsError from pixels import signal, ioutils @@ -18,7 +23,6 @@ from common_utils import file_utils -SAMPLE_RATE = 1000 class ActionLabels: """ @@ -53,13 +57,13 @@ class ActionLabels: punished = punished_light | punished_dark # trial type combos - light = miss_light | triggered_light | punished_light | default_light | - auto_light | reinf_light + light = miss_light | triggered_light | punished_light | default_light\ + | auto_light | reinf_light dark = miss_dark | triggered_dark | punished_dark | auto_dark | reinf_dark rewarded_light = triggered_light | default_light | auto_light | reinf_light rewarded_dark = triggered_dark | auto_dark | reinf_dark given_light = default_light | auto_light | reinf_light - given_dark = auto_dark | reinf_dark + given_dark = auto_dark | reinf_dark class Events: @@ -69,30 +73,33 @@ class Events: Events can be added on top of each other. 
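The alignment convention introduced in this patch, taking the first and the last sample at which a condition holds, is rising and falling edge detection on a 0/1 train. A standalone sketch of one common way to do that with NumPy; the toy trace and names are illustrative, while the real code below finds the same points by looking for breaks in the run of matching index values (.diff() != 1):

    import numpy as np

    # 1 while the world is the gray screen, 0 otherwise (toy trace that
    # starts and ends at 0; a train touching either boundary needs padding).
    in_gray = np.array([0, 0, 1, 1, 1, 0, 0, 1, 1, 0])

    edges = np.diff(in_gray)
    gray_on = np.where(edges == 1)[0] + 1   # first sample of each gray period
    gray_off = np.where(edges == -1)[0]     # last sample of each gray period

    print(gray_on)   # [2 7]
    print(gray_off)  # [4 8]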
""" # vr events - gray_on = 1 << 0 - gray_off = 1 << 1 - tunnel_on = 1 << 2 - tunnel_off = 1 << 3 - dark_on = 1 << 4 - dark_off = 1 << 5 - punish_on = 1 << 6 - punish_off = 1 << 7 - session_end = 1 << 8 + gray_on = 1 << 1 + gray_off = 1 << 2 + light_on = 1 << 3 + light_off = 1 << 4 + dark_on = 1 << 5 + dark_off = 1 << 6 + punish_on = 1 << 7 + punish_off = 1 << 8 + session_end = 1 << 9 + # NOTE if use this event to mark trial ending, begin of the first trial + # needs to be excluded + trial_end = gray_on | punish_on # positional events - black = 1 << 9 # 0 - 60 cm - wall = 1 << 10 # in between landmarks - landmark1 = 1 << 11 # 110 - 130 cm - landmark2 = 1 << 12 # 190 - 210 cm - landmark3 = 1 << 13 # 270 - 290 cm - landmark4 = 1 << 14 # 350 - 370 cm - landmark5 = 1 << 15 # 430 - 450 cm - reward_zone = 1 << 16 # 460 - 495 cm + black = 1 << 10 # 0 - 60 cm + wall = 1 << 11 # in between landmarks + landmark1 = 1 << 12 # 110 - 130 cm + landmark2 = 1 << 13 # 190 - 210 cm + landmark3 = 1 << 14 # 270 - 290 cm + landmark4 = 1 << 15 # 350 - 370 cm + landmark5 = 1 << 16 # 430 - 450 cm + reward_zone = 1 << 17 # 460 - 495 cm # sensors - valve_open = 1 << 17 - valve_closed = 1 << 18 - licked = 1 << 19 + valve_open = 1 << 18 + valve_closed = 1 << 19 + licked = 1 << 20 #run_start = 1 << 12 #run_stop = 1 << 13 @@ -114,51 +121,111 @@ class Events: class VR(Behaviour): - def _extract_action_labels(self, vr_data): - - # TEMPORARY load synced vr data - cache_dir = "/home/amz/interim/behaviour_cache/temp100/trials/" - vr_data = file_utils.load_pickle( - cache_dir + "20231130_az_WDAN07_upsampled.pickle" - #self.cache_dir + 'trials/' + self.name + '_upsampled.pickle' - ) - + def _extract_action_labels(self, vr, vr_data): # create action label array for actions & events - action_label = np.zeros((vr_data.shape[0], 2), dtype=np.int32) + action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.int32) + + # make sure position is not nan + no_nan = (~vr_data.position_in_tunnel.isna()) + # define in gray + in_gray = (vr_data.world_index == World.GRAY) + # define in dark + in_dark = (vr_data.world_index == World.DARK_5)\ + | (vr_data.world_index == World.DARK_2_5)\ + | (vr_data.world_index == World.DARK_FULL)\ + & no_nan + # define in white + in_white = (vr_data.world_index == World.WHITE) + # define in tunnel + in_tunnel = ~in_gray & ~in_white & no_nan + # define in light + in_light = (vr_data.world_index == World.TUNNEL)\ + & no_nan + # define light & dark trials + trial_light = (vr_data.trial_type == Trial_Type.LIGHT) + trial_dark = (vr_data.trial_type == Trial_Type.DARK) print(">> Mapping vr event times...") # get gray_on times, i.e., trial starts - gray_idx = vr_data.world_index[vr_data.world_index == World.GRAY].index + gray_idx = vr_data.world_index[in_gray].index # grays grays = np.where(gray_idx.diff() != 1)[0] - # first frame of gray - gray_on = gray_idx[grays].values + # find time for first frame of gray + gray_on_t = gray_idx[grays] + # find their index in vr data + gray_on = vr_data.index.get_indexer(gray_on_t) action_labels[gray_on, 1] += Events.gray_on - # last frame of gray - gray_off = np.append(gray_idx[grays[1:] - 1], gray_idx[-1]) + # find time for last frame of gray + gray_off_t = np.append(gray_idx[grays[1:] - 1], gray_idx[-1]) + # find their index in vr data + gray_off = vr_data.index.get_indexer(gray_off_t) action_labels[gray_off, 1] += Events.gray_off - # get tunnel_on times, i.e., tunnel starts - tunnel_idx = vr_data.world_index[vr_data.world_index == World.TUNNEL].index - # tunnels - tunnels = 
np.where(tunnel_idx.diff() != 1)[0] - - # tunnel starts is first frame after gray - tunnel_on = gray_off + 1 - action_labels[tunnel_on, 1] += Events.tunnel_on - - # get usual tunnel_off times - tunnel_off = np.append(tunnel_idx[tunnels[1:] - 1], tunnel_idx[-1]) - action_labels[tunnel_off, 1] += Events.tunnel_off + # get light_on times, i.e., light tunnel starts + #light_idx = vr_data.world_index[in_light].index + + # get data in light tunnel + light_data = vr_data[in_light] + # use light tunnel on as trial starts + trial_starts = np.where(light_data.trial_count.diff() != 0)[0] + # get interval of possible starting position + start_interval = int(vr.meta_item('rand_start_int')) + + # double check if all starting positions make sense + wrong_start = light_data.position_in_tunnel.iloc[trial_starts]\ + % start_interval != 0 + wrong_start_idx = np.where(wrong_start)[0] + if not wrong_start_idx.size == 0: + raise PixelsError(f"Check index {wrong_start_idx} of light_data,\ + \nthey do not have the correct starting position.") + + # get timestamps of when trial starts as light tunnel on + light_on_t = light_data.index[trial_starts] + # get index of when trial starts for action labels + light_on = vr_data.index.get_indexer(light_on_t) + action_labels[light_on, 1] += Events.light_on + + # get light trials + light_trials = vr_data[trial_light & in_light] + + # in light trials, light tunnel off when trials end + L_light_off_bool = (light_trials.trial_count.diff().shift(-1).fillna(1) != 0) + L_light_off_t = light_trials.index[L_light_off_bool] + L_light_off = vr_data.index.get_indexer(L_light_off_t) + action_labels[L_light_off, 1] += Events.light_off + + # NOTE: if dark trial is aborted, light tunnel only turns off once; but + # if it is a reward is dispensed, light tunnel turns off twice + + # get dark trials + dark_trials = vr_data[trial_dark & in_tunnel] + # get when dark starts + dark_on = dark_trials[dark_trials.world_index.diff() < 0] + dark_on_idx = vr_data.index.get_indexer(dark_on.index) + action_labels[dark_on_idx, 1] += Events.dark_on + + # get when light tunnel turns off before dark starts + D_light_off_idx = dark_on_idx - 1 + action_labels[D_light_off_idx, 1] += Events.light_off + + # TODO CONTINUE HERE JUN 26 + # separate dark trials into different chunks: + # with & without dark turned on, cuz some trials were punished even + # before dark on + # for trials went into dark, separate into before & after dark starts + assert 0 + + # get when dark ends + dark_in_light = vr_data[trial_dark & in_light] + + D_light_off_bool = (dark_trials_light.trial_count.diff().shift(-1).fillna(1) != 0) + D_light_off_t = dark_trials_light.index[D_light_off_bool] + D_light_off = vr_data.index.get_indexer(D_light_off_t) # get dark on - # define in dark - in_dark = (vr_data.world_index == World.DARK_5)\ - |(vr_data.world_index == World.DARK_2_5)\ - |(vr_data.world_index == World.DARK_FULL) # get index in dark in_dark_idx = vr_data[in_dark].index # number of dark_on does not match with number of dark trials caused by @@ -167,30 +234,110 @@ def _extract_action_labels(self, vr_data): darks = np.where(in_dark_idx.diff() != 1)[0] # first frame of dark - dark_on = in_dark_idx[darks].values + dark_on_t = in_dark_idx[darks] + dark_on = vr_data.index.get_indexer(dark_on_t) action_labels[dark_on, 1] += Events.dark_on # last frame of dark - dark_off = np.append(in_dark_idx[darks[1:] - 1], in_dark_idx[-1]) + dark_off_t = np.append(in_dark_idx[darks[1:] - 1], in_dark_idx[-1]) + dark_off = 
vr_data.index.get_indexer(dark_off_t) action_labels[dark_off, 1] += Events.dark_off # map licks - licked_idx = vr_data.lick_count[vr_data.lick_count == 1].index.values + licked_idx = np.where(vr_data.lick_count == 1)[0] action_labels[licked_idx, 1] += Events.licked print(">> Mapping vr action times...") + # map trial types + light_trials = vr_data[vr_data.trial_type == 0] + dark_trials = vr_data[vr_data.trial_type == 1] + + # map pre-reward zone + pre_zone = (vr_data.position_in_tunnel < vr.reward_zone_start) + # get in reward zone index + pre_zone_idx = vr_data[pre_zone].index + # get reward type while in reward zone + pre_reward_type = vr_data.reward_type.loc[pre_zone_idx] + + # default reward light trials + default_light = pre_reward_type.index[pre_reward_type == Outcome.DEFAULT] + default_light_id = light_trials.trial_count.loc[default_light].unique() + for i in default_light_id: + default_light_idx = np.where(vr_data.trial_count == i)[0] + action_labels[default_light_idx, 0] = ActionLabels.default_light + + # punished + punished_idx = vr_data.index[in_white] + # punished light + punished_light = light_trials.reindex(punished_idx).dropna() + punished_light_id = punished_light.trial_count.unique() + for i in punished_light_id: + punished_light_idx = np.where(vr_data.trial_count == i)[0] + action_labels[punished_light_idx, 0] = ActionLabels.punished_light + + # punished dark + punished_dark = dark_trials.reindex(punished_idx).dropna() + punished_dark_id = punished_dark.trial_count.unique() + for i in punished_dark_id: + punished_dark_idx = np.where(vr_data.trial_count == i)[0] + action_labels[punished_dark_idx, 0] = ActionLabels.punished_dark + # map reward zone - in_zone = (vr_data.position_in_tunnel >= self.reward_zone_start)\ - & (vr_data.position_in_tunnel <= self.reward_zone_end) + in_zone = (vr_data.position_in_tunnel >= vr.reward_zone_start)\ + & (vr_data.position_in_tunnel <= vr.reward_zone_end) # get in reward zone index in_zone_idx = vr_data[in_zone].index # get reward type while in reward zone reward_type = vr_data.reward_type.loc[in_zone_idx] + # triggered + triggered_idx = reward_type.index[reward_type == Outcome.TRIGGERED] + # triggered light trials + triggered_light = light_trials.reindex(triggered_idx).dropna() + triggered_light_id = triggered_light.trial_count.unique() + for i in triggered_light_id: + trig_light_idx = np.where(vr_data.trial_count == i)[0] + action_labels[trig_light_idx, 0] = ActionLabels.triggered_light + + # automatically rewarded light trials + auto_light = reward_type.index[reward_type == Outcome.AUTO_LIGHT] + auto_light_id = light_trials.trial_count.loc[auto_light].unique() + for i in auto_light_id: + auto_light_idx = np.where(vr_data.trial_count == i)[0] + action_labels[auto_light_idx, 0] = ActionLabels.auto_light + + # reinforcement reward light trials + reinf_light = reward_type.index[reward_type == Outcome.REINF_LIGHT] + reinf_light_id = light_trials.trial_count.loc[reinf_light].unique() + for i in reinf_light_id: + reinf_light_idx = np.where(vr_data.trial_count == i)[0] + action_labels[reinf_light_idx, 0] = ActionLabels.reinf_light + + # triggered dark trials + triggered_dark = dark_trials.reindex(triggered_idx).dropna() + triggered_dark_id = triggered_dark.trial_count.unique() + for i in triggered_dark_id: + trig_dark_idx = np.where(vr_data.trial_count == i)[0] + action_labels[trig_dark_idx, 0] = ActionLabels.triggered_dark + + # automatically rewarded dark trials + auto_dark = reward_type.index[reward_type == Outcome.AUTO_DARK] + auto_dark_id 
= dark_trials.trial_count.loc[auto_dark].unique() + for i in auto_dark_id: + auto_dark_idx = np.where(vr_data.trial_count == i)[0] + action_labels[auto_dark_idx, 0] = ActionLabels.auto_dark + + # reinforcement reward dark trials + reinf_dark = reward_type.index[reward_type == Outcome.REINF_DARK] + reinf_dark_id = dark_trials.trial_count.loc[reinf_dark].unique() + for i in reinf_dark_id: + reinf_dark_idx = np.where(vr_data.trial_count == i)[0] + action_labels[reinf_dark_idx, 0] = ActionLabels.reinf_dark + # after reward zone before trial resets - pass_zone = (vr_data.position_in_tunnel > self.reward_zone_end)\ - & (vr_data.position_in_tunnel <= self.tunnel_length) + pass_zone = (vr_data.position_in_tunnel > vr.reward_zone_end)\ + & (vr_data.position_in_tunnel <= vr.tunnel_length) # get passed reward zone index pass_zone_idx = vr_data[pass_zone].index end_reward_type = vr_data.reward_type.loc[pass_zone_idx] @@ -198,28 +345,31 @@ def _extract_action_labels(self, vr_data): # missed light trials miss_light = end_reward_type[end_reward_type == Outcome.ABORTED_LIGHT].index.values + miss_light_id = vr_data.trial_count.loc[miss_light].unique() + for i in miss_light_id: + miss_light_idx = np.where(vr_data.trial_count == i)[0] + action_labels[miss_light_idx, 0] = ActionLabels.miss_light + # missed dark trials miss_dark = end_reward_type[end_reward_type == Outcome.ABORTED_DARK].index.values + miss_dark_id = vr_data.trial_count.loc[miss_dark].unique() + for i in miss_dark_id: + miss_dark_idx = np.where(vr_data.trial_count == i)[0] + action_labels[miss_dark_idx, 0] = ActionLabels.miss_dark - # TODO jun 7 2024 action labels not mapped - + # put pixels timestamps in the third column + action_labels = np.column_stack((action_labels, vr_data.index.values)) return action_labels - def _old_extract_action_labels(self, vr_data, action_labels, plot=False): - - for i, trial in enumerate(self.metadata[rec_num]["trials"]): - side = _side_map[trial["spout"]] - outcome = trial["outcome"] - if outcome in _action_map: - action = _action_map[trial["outcome"]] - action_labels[led_onsets[i], 0] += getattr(ActionLabels, f"{action}_{side}") + def _check_action_labels(self, vr_data, action_labels, plot=True): + # TODO jun 9 2024 make this work, save the plot if plot: plt.clf() - _, axes = plt.subplots(4, 1, sharex=True, sharey=True) + _, axes = plt.subplots(4, 1, sharex=False, sharey=False) axes[0].plot(back_sensor_signal) if "/'Back_Sensor'/'0'" in behavioural_data: axes[1].plot(behavioural_data["/'Back_Sensor'/'0'"].values) From 48b9224e9a803615dc9f0c8c24329700363d27cb Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 27 Jun 2024 17:35:42 +0100 Subject: [PATCH 056/658] make sure basic events are mapped correctly --- pixels/behaviours/virtual_reality.py | 121 ++++++++++++++------------- 1 file changed, 62 insertions(+), 59 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index d298808..3ec0488 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -147,6 +147,7 @@ def _extract_action_labels(self, vr, vr_data): print(">> Mapping vr event times...") + # >>>> gray >>>> # get gray_on times, i.e., trial starts gray_idx = vr_data.world_index[in_gray].index # grays @@ -163,90 +164,92 @@ def _extract_action_labels(self, vr, vr_data): # find their index in vr data gray_off = vr_data.index.get_indexer(gray_off_t) action_labels[gray_off, 1] += Events.gray_off + # <<<< gray <<<< - # get light_on times, i.e., light tunnel starts - #light_idx 
= vr_data.world_index[in_light].index + # >>>> white >>>> + # get punish_on times + punish_idx = vr_data.world_index[in_white].index + # punishes + punishes = np.where(punish_idx.diff() != 1)[0] - # get data in light tunnel - light_data = vr_data[in_light] - # use light tunnel on as trial starts - trial_starts = np.where(light_data.trial_count.diff() != 0)[0] - # get interval of possible starting position - start_interval = int(vr.meta_item('rand_start_int')) + # find time for first frame of punish + punish_on_t = punish_idx[punishes] + # find their index in vr data + punish_on = vr_data.index.get_indexer(punish_on_t) + action_labels[punish_on, 1] += Events.punish_on - # double check if all starting positions make sense - wrong_start = light_data.position_in_tunnel.iloc[trial_starts]\ - % start_interval != 0 - wrong_start_idx = np.where(wrong_start)[0] - if not wrong_start_idx.size == 0: - raise PixelsError(f"Check index {wrong_start_idx} of light_data,\ - \nthey do not have the correct starting position.") - - # get timestamps of when trial starts as light tunnel on - light_on_t = light_data.index[trial_starts] - # get index of when trial starts for action labels + # find time for last frame of punish + punish_off_t = np.append(punish_idx[punishes[1:] - 1], punish_idx[-1]) + # find their index in vr data + punish_off = vr_data.index.get_indexer(punish_off_t) + action_labels[punish_off, 1] += Events.punish_off + # <<<< white <<<< + + # >>>> light >>>> + # get index of data in light tunnel + light_idx = vr_data[in_light].index + # get where light turns on + lights = np.where(light_idx.diff() != 1)[0] + # get timepoint of when light turns on + light_on_t = light_idx[lights] + # get index of when light turns on light_on = vr_data.index.get_indexer(light_on_t) action_labels[light_on, 1] += Events.light_on - # get light trials - light_trials = vr_data[trial_light & in_light] + # get interval of possible starting position + start_interval = int(vr.meta_item('rand_start_int')) - # in light trials, light tunnel off when trials end - L_light_off_bool = (light_trials.trial_count.diff().shift(-1).fillna(1) != 0) - L_light_off_t = light_trials.index[L_light_off_bool] - L_light_off = vr_data.index.get_indexer(L_light_off_t) - action_labels[L_light_off, 1] += Events.light_off + # find starting position in all light_on + trial_starts = light_on[np.where( + vr_data.iloc[light_on].position_in_tunnel % start_interval == 0 + )[0]] - # NOTE: if dark trial is aborted, light tunnel only turns off once; but - # if it is a reward is dispensed, light tunnel turns off twice + if not trial_starts.size == vr_data.trial_count.max(): + raise PixelsError(f"Number of trials does not equal to\ + \n{vr_data.trial_count.max()}.") + # NOTE: if trial starts at 0, the first position_in_tunnel value will + # NOT be nan - # get dark trials - dark_trials = vr_data[trial_dark & in_tunnel] - # get when dark starts - dark_on = dark_trials[dark_trials.world_index.diff() < 0] - dark_on_idx = vr_data.index.get_indexer(dark_on.index) - action_labels[dark_on_idx, 1] += Events.dark_on - - # get when light tunnel turns off before dark starts - D_light_off_idx = dark_on_idx - 1 - action_labels[D_light_off_idx, 1] += Events.light_off - - # TODO CONTINUE HERE JUN 26 - # separate dark trials into different chunks: - # with & without dark turned on, cuz some trials were punished even - # before dark on - # for trials went into dark, separate into before & after dark starts - assert 0 + # last frame of light + light_off_t = 
np.append(light_idx[lights[1:] - 1], light_idx[-1]) + light_off = vr_data.index.get_indexer(light_off_t) + action_labels[light_off, 1] += Events.light_off + # <<<< light <<<< - # get when dark ends - dark_in_light = vr_data[trial_dark & in_light] + # NOTE: if dark trial is aborted, light tunnel only turns off once; but + # if it is a reward is dispensed, light tunnel turns off twice - D_light_off_bool = (dark_trials_light.trial_count.diff().shift(-1).fillna(1) != 0) - D_light_off_t = dark_trials_light.index[D_light_off_bool] - D_light_off = vr_data.index.get_indexer(D_light_off_t) + # NOTE: number of dark_on does not match with number of dark trials + # caused by triggering punishment before dark - # get dark on + # >>>> dark >>>> # get index in dark - in_dark_idx = vr_data[in_dark].index - # number of dark_on does not match with number of dark trials caused by - # triggering punishment before dark - # darks - darks = np.where(in_dark_idx.diff() != 1)[0] + dark_idx = vr_data[in_dark].index + darks = np.where(dark_idx.diff() != 1)[0] # first frame of dark - dark_on_t = in_dark_idx[darks] + dark_on_t = dark_idx[darks] dark_on = vr_data.index.get_indexer(dark_on_t) action_labels[dark_on, 1] += Events.dark_on # last frame of dark - dark_off_t = np.append(in_dark_idx[darks[1:] - 1], in_dark_idx[-1]) + dark_off_t = np.append(dark_idx[darks[1:] - 1], dark_idx[-1]) dark_off = vr_data.index.get_indexer(dark_off_t) action_labels[dark_off, 1] += Events.dark_off + # <<<< dark <<<< + + # >>>> session ends >>>> + session_end = vr_data.shape[0] + action_labels[session_end, 1] += Events.session_end + # <<<< session ends <<<< - # map licks + # >>>> licks >>>> licked_idx = np.where(vr_data.lick_count == 1)[0] action_labels[licked_idx, 1] += Events.licked + # <<<< licks <<<< + # TODO jun 27 positional events and valve events needs mapping + assert 0 print(">> Mapping vr action times...") # map trial types From a8b5f42e86fb648504a653df9b53ab87518f5775 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 2 Jul 2024 19:23:02 +0100 Subject: [PATCH 057/658] add todo; move things around --- pixels/behaviours/virtual_reality.py | 31 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 3ec0488..f50203a 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -125,6 +125,9 @@ def _extract_action_labels(self, vr, vr_data): # create action label array for actions & events action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.int32) + # >>>> definitions >>>> + # TODO july 2 2024 no need to dropna here since already removed in + # processing # make sure position is not nan no_nan = (~vr_data.position_in_tunnel.isna()) # define in gray @@ -144,9 +147,11 @@ def _extract_action_labels(self, vr, vr_data): # define light & dark trials trial_light = (vr_data.trial_type == Trial_Type.LIGHT) trial_dark = (vr_data.trial_type == Trial_Type.DARK) + # <<<< definitions <<<< print(">> Mapping vr event times...") + assert 0 # >>>> gray >>>> # get gray_on times, i.e., trial starts gray_idx = vr_data.world_index[in_gray].index @@ -168,7 +173,7 @@ def _extract_action_labels(self, vr, vr_data): # >>>> white >>>> # get punish_on times - punish_idx = vr_data.world_index[in_white].index + punish_idx = vr_data[in_white].index # punishes punishes = np.where(punish_idx.diff() != 1)[0] @@ -239,7 +244,7 @@ def _extract_action_labels(self, vr, vr_data): # <<<< dark <<<< # >>>> session 
ends >>>> - session_end = vr_data.shape[0] + session_end = vr_data.shape[0] - 1 action_labels[session_end, 1] += Events.session_end # <<<< session ends <<<< @@ -249,46 +254,50 @@ def _extract_action_labels(self, vr, vr_data): # <<<< licks <<<< # TODO jun 27 positional events and valve events needs mapping - assert 0 print(">> Mapping vr action times...") + # map trial types - light_trials = vr_data[vr_data.trial_type == 0] - dark_trials = vr_data[vr_data.trial_type == 1] + light_trials = vr_data[trial_light & no_nan] + dark_trials = vr_data[trial_dark & no_nan] # map pre-reward zone pre_zone = (vr_data.position_in_tunnel < vr.reward_zone_start) + # map post-reward zone + post_zone = (vr_data.position_in_tunnel > vr.reward_zone_end) # get in reward zone index pre_zone_idx = vr_data[pre_zone].index # get reward type while in reward zone pre_reward_type = vr_data.reward_type.loc[pre_zone_idx] + # >>>> light default reward >>>> # default reward light trials default_light = pre_reward_type.index[pre_reward_type == Outcome.DEFAULT] default_light_id = light_trials.trial_count.loc[default_light].unique() for i in default_light_id: default_light_idx = np.where(vr_data.trial_count == i)[0] action_labels[default_light_idx, 0] = ActionLabels.default_light + # >>>> light default reward >>>> - # punished - punished_idx = vr_data.index[in_white] + # >>>> punished >>>> # punished light - punished_light = light_trials.reindex(punished_idx).dropna() + punished_light = vr_data[trial_light & no_nan & in_white] punished_light_id = punished_light.trial_count.unique() for i in punished_light_id: punished_light_idx = np.where(vr_data.trial_count == i)[0] action_labels[punished_light_idx, 0] = ActionLabels.punished_light # punished dark - punished_dark = dark_trials.reindex(punished_idx).dropna() + punished_dark = vr_data[trial_dark & no_nan & in_white] punished_dark_id = punished_dark.trial_count.unique() for i in punished_dark_id: punished_dark_idx = np.where(vr_data.trial_count == i)[0] action_labels[punished_dark_idx, 0] = ActionLabels.punished_dark + # <<<< punished <<<< + assert 0 # map reward zone - in_zone = (vr_data.position_in_tunnel >= vr.reward_zone_start)\ - & (vr_data.position_in_tunnel <= vr.reward_zone_end) + in_zone = ~pre_zone & ~post_zone # get in reward zone index in_zone_idx = vr_data[in_zone].index # get reward type while in reward zone From d4f1105d7a5b1a25d7684b91de706447c7021583 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 3 Jul 2024 19:45:03 +0100 Subject: [PATCH 058/658] map reward types with loop --- pixels/behaviours/virtual_reality.py | 233 ++++++++++----------------- 1 file changed, 85 insertions(+), 148 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index f50203a..18408d7 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -37,19 +37,19 @@ class ActionLabels: # TODO jun 7 2024 does the name "action" make sense? 
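# >>>> bitmask sketch >>>>
# Illustration only (toy values, not part of the diff above/below): each label is a
# distinct power of two, so labels can be OR-ed into combos (e.g. miss = miss_light |
# miss_dark) and trials are later recovered with a bitwise AND, exactly as align_trials
# does on the actions column.
import numpy as np
miss_light, miss_dark = 1 << 0, 1 << 1
miss = miss_light | miss_dark                          # combo flag matches either variant
actions = np.array([miss_light, 0, miss_dark, 1 << 4], dtype=np.int32)
print(np.where(np.bitwise_and(actions, miss))[0])      # -> [0 2], trials with any miss label
# <<<< bitmask sketch <<<<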
# triggered vr trials - miss_light = 1 << 0 - miss_dark = 1 << 1 - triggered_light = 1 << 2 - triggered_dark = 1 << 3 - punished_light = 1 << 4 - punished_dark = 1 << 5 + miss_light = 1 << 0 # 1 + miss_dark = 1 << 1 # 2 + triggered_light = 1 << 2 # 4 + triggered_dark = 1 << 3 # 8 + punished_light = 1 << 4 # 16 + punished_dark = 1 << 5 # 32 # given reward - default_light = 1 << 6 - auto_light = 1 << 7 - auto_dark = 1 << 8 - reinf_light = 1 << 9 - reinf_dark = 1 << 10 + default_light = 1 << 6 # 64 + auto_light = 1 << 7 # 128 + auto_dark = 1 << 8 # 256 + reinf_light = 1 << 9 # 512 + reinf_dark = 1 << 10 # 1024 # combos miss = miss_light | miss_dark @@ -73,15 +73,15 @@ class Events: Events can be added on top of each other. """ # vr events - gray_on = 1 << 1 - gray_off = 1 << 2 - light_on = 1 << 3 - light_off = 1 << 4 - dark_on = 1 << 5 - dark_off = 1 << 6 - punish_on = 1 << 7 - punish_off = 1 << 8 - session_end = 1 << 9 + gray_on = 1 << 1 # 2 + gray_off = 1 << 2 # 4 + light_on = 1 << 3 # 8 + light_off = 1 << 4 # 16 + dark_on = 1 << 5 # 32 + dark_off = 1 << 6 # 64 + punish_on = 1 << 7 # 128 + punish_off = 1 << 8 # 256 + session_end = 1 << 9 # 512 # NOTE if use this event to mark trial ending, begin of the first trial # needs to be excluded trial_end = gray_on | punish_on @@ -97,18 +97,17 @@ class Events: reward_zone = 1 << 17 # 460 - 495 cm # sensors - valve_open = 1 << 18 - valve_closed = 1 << 19 - licked = 1 << 20 + valve_open = 1 << 18 # 262144 + valve_closed = 1 << 19 # 524288 + licked = 1 << 20 # 1048576 #run_start = 1 << 12 #run_stop = 1 << 13 -# convert the trial data into Actions and Events -_action_map = { +# map trial outcome +_outcome_map = { Outcome.ABORTED_DARK: "miss_dark", Outcome.ABORTED_LIGHT: "miss_light", - Outcome.NONE: "miss", Outcome.TRIGGERED: "triggered", Outcome.AUTO_LIGHT: "auto_light", Outcome.DEFAULT: "default_light", @@ -117,7 +116,8 @@ class Events: Outcome.REINF_DARK: "reinf_dark", } - +# function to look up trial type +trial_type_lookup = {v: k for k, v in vars(Trial_Type).items()} class VR(Behaviour): @@ -126,24 +126,18 @@ def _extract_action_labels(self, vr, vr_data): action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.int32) # >>>> definitions >>>> - # TODO july 2 2024 no need to dropna here since already removed in - # processing - # make sure position is not nan - no_nan = (~vr_data.position_in_tunnel.isna()) # define in gray in_gray = (vr_data.world_index == World.GRAY) # define in dark in_dark = (vr_data.world_index == World.DARK_5)\ | (vr_data.world_index == World.DARK_2_5)\ - | (vr_data.world_index == World.DARK_FULL)\ - & no_nan + | (vr_data.world_index == World.DARK_FULL) # define in white in_white = (vr_data.world_index == World.WHITE) # define in tunnel - in_tunnel = ~in_gray & ~in_white & no_nan + in_tunnel = ~in_gray & ~in_white # define in light - in_light = (vr_data.world_index == World.TUNNEL)\ - & no_nan + in_light = (vr_data.world_index == World.TUNNEL) # define light & dark trials trial_light = (vr_data.trial_type == Trial_Type.LIGHT) trial_dark = (vr_data.trial_type == Trial_Type.DARK) @@ -151,7 +145,6 @@ def _extract_action_labels(self, vr, vr_data): print(">> Mapping vr event times...") - assert 0 # >>>> gray >>>> # get gray_on times, i.e., trial starts gray_idx = vr_data.world_index[in_gray].index @@ -257,121 +250,65 @@ def _extract_action_labels(self, vr, vr_data): print(">> Mapping vr action times...") - # map trial types - light_trials = vr_data[trial_light & no_nan] - dark_trials = vr_data[trial_dark & no_nan] - - # map 
pre-reward zone - pre_zone = (vr_data.position_in_tunnel < vr.reward_zone_start) - # map post-reward zone - post_zone = (vr_data.position_in_tunnel > vr.reward_zone_end) - # get in reward zone index - pre_zone_idx = vr_data[pre_zone].index - # get reward type while in reward zone - pre_reward_type = vr_data.reward_type.loc[pre_zone_idx] - - # >>>> light default reward >>>> - # default reward light trials - default_light = pre_reward_type.index[pre_reward_type == Outcome.DEFAULT] - default_light_id = light_trials.trial_count.loc[default_light].unique() - for i in default_light_id: - default_light_idx = np.where(vr_data.trial_count == i)[0] - action_labels[default_light_idx, 0] = ActionLabels.default_light - # >>>> light default reward >>>> - - # >>>> punished >>>> - # punished light - punished_light = vr_data[trial_light & no_nan & in_white] - punished_light_id = punished_light.trial_count.unique() - for i in punished_light_id: - punished_light_idx = np.where(vr_data.trial_count == i)[0] - action_labels[punished_light_idx, 0] = ActionLabels.punished_light - - # punished dark - punished_dark = vr_data[trial_dark & no_nan & in_white] - punished_dark_id = punished_dark.trial_count.unique() - for i in punished_dark_id: - punished_dark_idx = np.where(vr_data.trial_count == i)[0] - action_labels[punished_dark_idx, 0] = ActionLabels.punished_dark - # <<<< punished <<<< - + # >>>> map reward types >>>> + # get non-zero reward types + reward_not_none = (vr_data.reward_type != Outcome.NONE) + + for t, trial in enumerate(vr_data.trial_count.unique()): + # get current trial + of_trial = (vr_data.trial_count == trial) + # get index of current trial + trial_idx = np.where(of_trial)[0] + # find where is non-zero reward type in current trial + reward_typed = vr_data[of_trial & reward_not_none] + # get trial type of current trial + trial_type = int(vr_data[of_trial].trial_type.unique()) + # get name of trial type in string + trial_type_str = trial_type_lookup.get(trial_type).lower() + + # >>>> punished >>>> + if (reward_typed.size == 0)\ + & (vr_data[of_trial & in_white].size != 0): + # punished outcome + outcome = f"punished_{trial_type_str}" + action_labels[trial_idx, 0] = getattr(ActionLabels, outcome) + # <<<< punished <<<< + else: + # >>>> non punished >>>> + # get non-zero reward type in current trial + reward_type = int(reward_typed.reward_type.unique()) + # double check reward_type is in outcome map + assert (reward_type in _outcome_map) + + """ triggered """ + # catch triggered trials and separate trial types + if reward_type == Outcome.TRIGGERED: + outcome = f"{_outcome_map[reward_type]}_{trial_type_str}" + else: + """ given & aborted """ + outcome = _outcome_map[reward_type] + # label outcome + action_labels[trial_idx, 0] = getattr(ActionLabels, outcome) + # <<<< non punished <<<< + + # >>>> non aborted, valve only >>>> + # if not aborted, map valve open & closed + if reward_type > Outcome.NONE: + # map valve open + valve_open_idx = vr_data.index.get_indexer([reward_typed.index[0]]) + action_labels[valve_open_idx, 1] += Events.valve_open + # map valve closed + valve_closed_idx = vr_data.index.get_indexer( + [reward_typed.index[-1]+1] + ) + action_labels[valve_closed_idx, 1] += Events.valve_closed + # <<<< non aborted, valve only <<<< + # <<<< map reward types <<<< assert 0 - # map reward zone - in_zone = ~pre_zone & ~post_zone - # get in reward zone index - in_zone_idx = vr_data[in_zone].index - # get reward type while in reward zone - reward_type = vr_data.reward_type.loc[in_zone_idx] - - # 
triggered - triggered_idx = reward_type.index[reward_type == Outcome.TRIGGERED] - # triggered light trials - triggered_light = light_trials.reindex(triggered_idx).dropna() - triggered_light_id = triggered_light.trial_count.unique() - for i in triggered_light_id: - trig_light_idx = np.where(vr_data.trial_count == i)[0] - action_labels[trig_light_idx, 0] = ActionLabels.triggered_light - - # automatically rewarded light trials - auto_light = reward_type.index[reward_type == Outcome.AUTO_LIGHT] - auto_light_id = light_trials.trial_count.loc[auto_light].unique() - for i in auto_light_id: - auto_light_idx = np.where(vr_data.trial_count == i)[0] - action_labels[auto_light_idx, 0] = ActionLabels.auto_light - - # reinforcement reward light trials - reinf_light = reward_type.index[reward_type == Outcome.REINF_LIGHT] - reinf_light_id = light_trials.trial_count.loc[reinf_light].unique() - for i in reinf_light_id: - reinf_light_idx = np.where(vr_data.trial_count == i)[0] - action_labels[reinf_light_idx, 0] = ActionLabels.reinf_light - - # triggered dark trials - triggered_dark = dark_trials.reindex(triggered_idx).dropna() - triggered_dark_id = triggered_dark.trial_count.unique() - for i in triggered_dark_id: - trig_dark_idx = np.where(vr_data.trial_count == i)[0] - action_labels[trig_dark_idx, 0] = ActionLabels.triggered_dark - - # automatically rewarded dark trials - auto_dark = reward_type.index[reward_type == Outcome.AUTO_DARK] - auto_dark_id = dark_trials.trial_count.loc[auto_dark].unique() - for i in auto_dark_id: - auto_dark_idx = np.where(vr_data.trial_count == i)[0] - action_labels[auto_dark_idx, 0] = ActionLabels.auto_dark - - # reinforcement reward dark trials - reinf_dark = reward_type.index[reward_type == Outcome.REINF_DARK] - reinf_dark_id = dark_trials.trial_count.loc[reinf_dark].unique() - for i in reinf_dark_id: - reinf_dark_idx = np.where(vr_data.trial_count == i)[0] - action_labels[reinf_dark_idx, 0] = ActionLabels.reinf_dark - - # after reward zone before trial resets - pass_zone = (vr_data.position_in_tunnel > vr.reward_zone_end)\ - & (vr_data.position_in_tunnel <= vr.tunnel_length) - # get passed reward zone index - pass_zone_idx = vr_data[pass_zone].index - end_reward_type = vr_data.reward_type.loc[pass_zone_idx] - - # missed light trials - miss_light = end_reward_type[end_reward_type == - Outcome.ABORTED_LIGHT].index.values - miss_light_id = vr_data.trial_count.loc[miss_light].unique() - for i in miss_light_id: - miss_light_idx = np.where(vr_data.trial_count == i)[0] - action_labels[miss_light_idx, 0] = ActionLabels.miss_light - - # missed dark trials - miss_dark = end_reward_type[end_reward_type == - Outcome.ABORTED_DARK].index.values - miss_dark_id = vr_data.trial_count.loc[miss_dark].unique() - for i in miss_dark_id: - miss_dark_idx = np.where(vr_data.trial_count == i)[0] - action_labels[miss_dark_idx, 0] = ActionLabels.miss_dark # put pixels timestamps in the third column action_labels = np.column_stack((action_labels, vr_data.index.values)) + assert 0 return action_labels From 295eaafa400ac76cbe89efa73eb011e3c1c5e395 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 3 Jul 2024 19:55:57 +0100 Subject: [PATCH 059/658] remove assertion error --- pixels/behaviours/virtual_reality.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 18408d7..942fca5 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -304,11 +304,9 @@ def _extract_action_labels(self, 
vr, vr_data): action_labels[valve_closed_idx, 1] += Events.valve_closed # <<<< non aborted, valve only <<<< # <<<< map reward types <<<< - assert 0 # put pixels timestamps in the third column action_labels = np.column_stack((action_labels, vr_data.index.values)) - assert 0 return action_labels From a310464851a55a23e4ce866bd18aa198466ebe32 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 3 Jul 2024 19:56:14 +0100 Subject: [PATCH 060/658] add multiprocessing for resampling --- pixels/signal.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/pixels/signal.py b/pixels/signal.py index 0d8cba1..65078d6 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -6,6 +6,9 @@ import time from pathlib import Path +import multiprocessing as mp +from joblib import Parallel, delayed + import cv2 import numpy as np import matplotlib.pyplot as plt @@ -76,6 +79,28 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): chunks = int(np.ceil(size_bytes / 5368709120)) chunk_size = int(np.ceil(cols / chunks)) + # get index & chunk data + #chunk_indices = [(i, min(i + chunk_size, cols)) for i in range(0, cols, chunk_size)] + #chunks_data = [array[:, start:end] for start, end in chunk_indices] + # get number of processes/jobs + #n_processes = mp.cpu_count() - 2 + # initiate a mp pool + #pool = mp.Pool(n_processes) + ## does resample for each chunk + #results = pool.starmap( + # _resample_chunk, + # [(chunk, up, down, poly, padtype) for chunk in chunks_data], + #) + + ## stop adding task to pool + #pool.close() + ## wait till all tasks in pool completed + #pool.join() + #results = Parallel(n_jobs=-1)( + # delayed(_resample_chunk)(chunk, up, down, poly, padtype) for chunk in chunks_data + #) + #print(">> mapped chunk data to pool...") + if chunks > 1: print(f" 0%", end="\r") current = 0 @@ -100,10 +125,26 @@ def resample(array, from_hz, to_hz, poly=True, padtype=None): new_data.append(result) current += chunk_size print(f" {100 * current / cols:.1f}%", end="\r") + #new_data = np.concatenate(results, axis=1).squeeze() return np.concatenate(new_data, axis=1).squeeze()#.astype(np.int16) + #return new_data +def _resample_chunk(chunk_data, up, down, poly, padtype): + if poly: + result = scipy.signal.resample_poly( + chunk_data, up, down, axis=0, padtype=padtype or 'minimum', + ) + else: + samp_num = int(np.ceil( + chunk_data.shape[0] * (up / down) + )) + result = scipy.signal.resample( + chunk_data, samp_num, axis=0 + ) + return result + def binarise(data): """ This normalises an array to between 0 and 1 and then makes all values below 0.5 @@ -197,7 +238,7 @@ def find_sync_lag(array1, array2, plot=False): if plot.exists(): plot = plot.with_name(plot.stem + '_' + time.strftime('%y%m%d-%H%M%S') + '.png') fig, axes = plt.subplots(nrows=2, ncols=1) - plot_length = min(length, 5000) + plot_length = min(length, 30000) if lag >= 0: axes[0].plot(array1[lag:lag + plot_length]) axes[1].plot(array2[:plot_length]) From 190fa779c70e01463ec4e9f208f8b74282074ef9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 3 Jul 2024 19:57:42 +0100 Subject: [PATCH 061/658] make si joblib global; start trial alignment for vr --- pixels/behaviours/base.py | 173 +++++++++++++++++++++++++------------- 1 file changed, 113 insertions(+), 60 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 89dec52..e525b93 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -31,6 +31,7 @@ import spikeinterface.sorters as ss import 
spikeinterface.curation as sc import spikeinterface.exporters as sexp +import spikeinterface.preprocessing as spre import spikeinterface.postprocessing as spost from scipy import interpolate from tables import HDF5ExtError @@ -46,6 +47,15 @@ np.random.seed(BEHAVIOUR_HZ) +# set si job_kwargs +job_kwargs = dict( + n_jobs=-1, + chunk_duration='1s', + progress_bar=True, +) + +si.set_global_job_kwargs(**job_kwargs) + def _cacheable(method): """ @@ -189,11 +199,11 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self.drop_data() self.spike_meta = [ - ioutils.read_meta(self.find_file(f['spike_meta'], copy=True)) for f in self.files + ioutils.read_meta(self.find_file(f['spike_meta'], copy=False)) for f in self.files + ] + self.lfp_meta = [ + ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files ] - #self.lfp_meta = [ - # ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files - #] # environmental variable PIXELS_CACHE={0,1} can be {disable,enable} cache self.set_cache(bool(int(os.environ.get("PIXELS_CACHE", 1)))) @@ -202,6 +212,9 @@ def drop_data(self): """ Clear attributes that store data to clear some memory. """ + # assume each pixels session only has one behaviour session, no matter + # number of probes + #self._action_labels = None self._action_labels = [None] * len(self.files) self._behavioural_data = [None] * len(self.files) self._spike_data = [None] * len(self.files) @@ -241,6 +254,9 @@ def get_probe_depth(self): Load probe depth in um from file if it has been recorded. """ for stream_num, depth in enumerate(self._probe_depths): + # TODO jun 12 2024 skip stream 1 for now + if stream_num > 0: + continue if depth is None: try: depth_file = self.processed / 'depth.txt' @@ -249,7 +265,9 @@ def get_probe_depth(self): fd.readlines()][0] except: depth_file = self.processed / self.files[stream_num]["depth_info"] - self._probe_depths[stream_num] = json.load(open(depth_file, mode="r"))["clustering"] + #self._probe_depths[stream_num] = json.load(open(depth_file, mode="r"))["clustering"] + self._probe_depths[stream_num] = json.load( + open(depth_file, mode="r"))["manipulator"] else: msg = f": Can't load probe depth: please add it in um to\ \nprocessed/{self.name}/depth.txt, or save it with other depth related\ @@ -611,49 +629,54 @@ def process_lfp(self): for rec_num, recording in enumerate(self.files): print(f">>>>> Processing LFP for recording {rec_num + 1} of {len(self.files)}") - data_file = self.find_file(recording['lfp_data']) - orig_rate = self.lfp_meta[rec_num]['imSampRate'] - num_chans = self.lfp_meta[rec_num]['nSavedChans'] - - print("> Mapping LFP data") - data = ioutils.read_bin(data_file, num_chans) - output = self.processed / recording['lfp_processed'] if output.exists(): continue + data_file = self.find_file(recording['lfp_data']) + orig_rate = int(self.lfp_meta[rec_num]['imSampRate']) + num_chans = int(self.lfp_meta[rec_num]['nSavedChans']) + + print("> Mapping LFP data") + data = se.read_binary(data_file, orig_rate, np.int16, num_chans-1) + #data = ioutils.read_bin(data_file, num_chans) + print("> Performing median subtraction across channels for each timepoint") - subtracted = signal.median_subtraction(data, axis=1) + #subtracted = signal.median_subtraction(data, axis=1) + cmred = spre.common_reference(data) print(f"> Downsampling to {self.sample_rate} Hz") - downsampled = signal.resample(subtracted, orig_rate, self.sample_rate, False) - sync_chan = downsampled[:, -1] - downsampled = downsampled[:, :-1] + 
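# >>>> reference sketch >>>>
# Illustration only (toy array, not part of the diff): spre.common_reference defaults to
# a global median reference, which should reproduce the old signal.median_subtraction
# step, i.e. subtracting the per-timepoint median across channels.
import numpy as np
lfp = np.array([[1.0, 2.0, 9.0],
                [0.0, 4.0, 8.0]])                      # rows = timepoints, cols = channels
referenced = lfp - np.median(lfp, axis=1, keepdims=True)
print(referenced)                                       # each row now has median 0
# <<<< reference sketch <<<<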
#downsampled = signal.resample(subtracted, orig_rate, self.sample_rate, False) + downsampled = spre.resample(cmred, self.sample_rate) + # get traces + downsampled = downsampled.get_traces() - if self._lag[rec_num] is None: - self.sync_data(rec_num, sync_channel=data[:, -1]) - lag_start, lag_end = self._lag[rec_num] + # TODO jun 10 2024 to get sync channel here and find lag? + #sync_chan = downsampled[:, -1] + #downsampled = downsampled[:, :-1] + + #if self._lag[rec_num] is None: + # self.sync_data(rec_num, sync_channel=data[:, -1]) + #lag_start, lag_end = self._lag[rec_num] sd = self.processed / recording['lfp_sd'] if sd.exists(): continue - SDs = [] - for i in range(downsampled.shape[1]): - SDs.append(np.std(downsampled[:, i])) + SDs = np.std(downsampled, axis=0) results = dict( median=np.median(SDs), - SDs=SDs, + SDs=SDs.tolist(), ) print(f"> Saving standard deviation (and their median) of each channel") with open(sd, 'w') as fd: json.dump(results, fd) - if lag_end < 0: - data = data[:lag_end] - if lag_start < 0: - data = data[- lag_start:] - print(f"> Saving median subtracted & downsampled LFP to {output}") + #if lag_end < 0: + # data = data[:lag_end] + #if lag_start < 0: + # data = data[- lag_start:] + #print(f"> Saving median subtracted & downsampled LFP to {output}") # save in .npy format np.save( file=output, @@ -700,18 +723,26 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: #_dir = self.interim if args == None: + #args = f"-no_run_fld\ + # -g=0,9\ + # -t=0,9\ + # -prb=0:1\ + # -prb_miss_ok\ + # -ap\ + # -lf\ + # -apfilter=butter,12,300,9000\ + # -lffilter=butter,12,0.5,300\ + # -xd=2,0,384,6,350,160\ + # -gblcar\ + # -gfix=0.2,0.1,0.02" args = f"-no_run_fld\ - -g=0,9\ - -t=0,9\ - -prb=0:1\ + -g=0\ + -t=0\ + -prb=0\ -prb_miss_ok\ -ap\ - -lf\ - -apfilter=butter,12,300,9000\ - -lffilter=butter,12,0.5,300\ - -xd=2,0,384,6,350,160\ - -gblcar\ - -gfix=0.2,0.1,0.02" + -xd=2,0,384,6,20,15\ + -xid=2,0,384,6,20,15" session_args = f"-dir={self.interim} -run={self.name} -dest={self.interim} " + args print(f"\ncatgt args of {self.name}: \n{session_args}") @@ -787,17 +818,18 @@ def sort_spikes(self, CatGT_app=None, old=False): """ streams = {} # set chunks for spikeinterface operations - job_kwargs = dict( - n_jobs=-3, # -1: num of job equals num of cores - chunk_duration="1s", - progress_bar=True, - ) + #job_kwargs = dict( + # n_jobs=-3, # -1: num of job equals num of cores + # chunk_duration="1s", + # progress_bar=True, + #) #concat_rec, output = self.load_recording() #assert 0 #TODO: jan 3 see if ks can run normally now using load_recording() self.run_catgt(CatGT_app=CatGT_app) + assert 0 for _, files in enumerate(self.files): if not CatGT_app == None: @@ -1637,9 +1669,8 @@ def _get_spike_times(self, remapped=False): return spike_times[0] - def _get_aligned_spike_times( - self, label, event, duration, rate=False, sigma=None, units=None - ): + def _get_aligned_spike_times(self, label, event, duration, rate=False, + sigma=None, units=None, end_event=None): """ Returns spike times for each unit within a given time window around an event. 
align_trials delegates to this function, and should be used for getting aligned @@ -1656,21 +1687,38 @@ def _get_aligned_spike_times( spikes = self._get_spike_times()[units] # Convert to ms (self.sample_rate) spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate + # get index of spike times in data in sample_rate Hz too + spikes_idx = spikes * self.sample_rate - if rate: + if rate & (type(duration) != str): # pad ends with 1 second extra to remove edge effects from convolution duration += 2 + half = int((self.sample_rate * duration) / 2) + scan_duration = self.sample_rate * 8 - scan_duration = self.sample_rate * 8 - half = int((self.sample_rate * duration) / 2) cursor = 0 # In sample points i = -1 rec_trials = {} + # since each session has one behaviour session, now only one action + # label file + actions = action_labels[:, 0] + events = action_labels[:, 1] + # select trials + selected_trials = np.where(np.bitwise_and(actions, label))[0] + # map starts by event + starts = np.where(np.bitwise_and(events, event))[0] + ends = np.where(np.bitwise_and(events, end_event))[0] + assert 0 + # TODO jul 3 2024 continue here!!!! + + for rec_num in range(len(self.files)): - actions = action_labels[rec_num][:, 0] - events = action_labels[rec_num][:, 1] - trial_starts = np.where(np.bitwise_and(actions, label))[0] + # TODO jun 12 2024 skip other streams for now + if rec_num > 0: + continue + #actions = action_labels[rec_num][:, 0] + #events = action_labels[rec_num][:, 1] # Account for multiple raw data files meta = self.spike_meta[rec_num] @@ -1794,6 +1842,10 @@ def select_units( widths = None for stream_num, info in enumerate(cluster_info): + # TODO jun 12 2024 skip stream 1 for now + if stream_num > 0: + continue + id_key = 'id' if 'id' in info else 'cluster_id' grouping = 'KSLabel' if uncurated else 'group' @@ -1830,7 +1882,7 @@ def _get_neuro_raw(self, kind): raw = [] meta = getattr(self, f"{kind}_meta") for rec_num, recording in enumerate(self.files): - data_file = self.find_file(recording[f'{kind}_data']) + data_file = self.find_file(recording[f'{kind}_data'], copy=False) orig_rate = int(meta[rec_num]['imSampRate']) num_chans = int(meta[rec_num]['nSavedChans']) factor = orig_rate / self.sample_rate @@ -1866,7 +1918,7 @@ def get_lfp_data_raw(self): @_cacheable def align_trials( self, label, event, data='spike_times', raw=False, duration=1, sigma=None, - units=None, dlc_project=None, video_match=None, + units=None, dlc_project=None, video_match=None, end_event=None, ): """ Get trials aligned to an event. 
This finds all instances of label in the action @@ -1929,7 +1981,7 @@ def align_trials( # we let a dedicated function handle aligning spike times return self._get_aligned_spike_times( label, event, duration, rate=data == "spike_rate", sigma=sigma, - units=units + units=units, end_event=end_event, ) if data == "motion_tracking" and not dlc_project: @@ -2063,6 +2115,7 @@ def align_clips(self, label, event, video_match, duration=1): trial_starts = np.where(np.bitwise_and(actions, label))[0] behavioural_data = ioutils.read_tdms(self.find_file(recording['behaviour'])) + assert 0 behavioural_data = behavioural_data["/'CamFrames'/'0'"] behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) behavioural_data.iloc[:len(behav_array)] = np.squeeze(behav_array) @@ -2223,12 +2276,12 @@ def get_spike_waveforms(self, units=None, method='phy'): #TODO: implement spikeinterface waveform extraction elif method == 'spikeinterface': - # set chunks - job_kwargs = dict( - n_jobs=-3, # -1: num of job equals num of cores - chunk_duration="1s", - progress_bar=True, - ) + ## set chunks + #job_kwargs = dict( + # n_jobs=-3, # -1: num of job equals num of cores + # chunk_duration="1s", + # progress_bar=True, + #) recording, _ = self.load_recording() #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure put old code under loop of stream so it does not break! From a0d9e9ff95a03c66cfcc76469ca468c1bf9137f5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 3 Jul 2024 20:11:45 +0100 Subject: [PATCH 062/658] add todo --- pixels/behaviours/base.py | 2 +- pixels/behaviours/virtual_reality.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e525b93..293dae0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1708,9 +1708,9 @@ def _get_aligned_spike_times(self, label, event, duration, rate=False, selected_trials = np.where(np.bitwise_and(actions, label))[0] # map starts by event starts = np.where(np.bitwise_and(events, event))[0] - ends = np.where(np.bitwise_and(events, end_event))[0] assert 0 # TODO jul 3 2024 continue here!!!! + ends = np.where(np.bitwise_and(events, end_event))[0] for rec_num in range(len(self.files)): diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 942fca5..72386e6 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -2,6 +2,7 @@ This module provides reach task specific operations. """ +# TODO july 3 2024 is this description true???? 
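# >>>> on/off detection sketch >>>>
# Illustration only (toy array, not part of the diff): the "first 1 / last 1" alignment
# described in the NOTE below is what the index.diff() != 1 idiom in
# _extract_action_labels computes. Self-contained numpy version:
import numpy as np
ev = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0])
idx = np.where(ev == 1)[0]                              # samples where the event is on
run_breaks = np.where(np.diff(idx) != 1)[0]             # gaps between contiguous runs
onsets = idx[np.concatenate(([0], run_breaks + 1))]     # first 1 of each run -> [2 6]
offsets = np.append(idx[run_breaks], idx[-1])           # last 1 of each run  -> [4 7]
print(onsets, offsets)
# <<<< on/off detection sketch <<<<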
# NOTE: for event alignment, we align to the first timepoint when the event # starts, and the last timepoint before the event ends, i.e., think of an event # as a train of 0s and 1s, we align to the first 1s and the last 1s of a given @@ -84,7 +85,8 @@ class Events: session_end = 1 << 9 # 512 # NOTE if use this event to mark trial ending, begin of the first trial # needs to be excluded - trial_end = gray_on | punish_on + # TODO jul 3 2024 trial start & end is not properly defined + trial_end = gray_on | punish_on | session_end # positional events black = 1 << 10 # 0 - 60 cm From d4fb4ece28f1020bfe0f039f8a72a44a13a72ed7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 4 Jul 2024 17:58:56 +0100 Subject: [PATCH 063/658] label trial starts & ends; make sure to align to first & last on of the event --- pixels/behaviours/virtual_reality.py | 48 ++++++++++++++++------------ 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 72386e6..b7cb1e5 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -2,11 +2,10 @@ This module provides reach task specific operations. """ -# TODO july 3 2024 is this description true???? # NOTE: for event alignment, we align to the first timepoint when the event # starts, and the last timepoint before the event ends, i.e., think of an event # as a train of 0s and 1s, we align to the first 1s and the last 1s of a given -# event. +# event, except for licks since it could only be on or off per frame. from __future__ import annotations @@ -74,6 +73,7 @@ class Events: Events can be added on top of each other. """ # vr events + trial_start = 1 << 0 # 1 gray_on = 1 << 1 # 2 gray_off = 1 << 2 # 4 light_on = 1 << 3 # 8 @@ -82,11 +82,7 @@ class Events: dark_off = 1 << 6 # 64 punish_on = 1 << 7 # 128 punish_off = 1 << 8 # 256 - session_end = 1 << 9 # 512 - # NOTE if use this event to mark trial ending, begin of the first trial - # needs to be excluded - # TODO jul 3 2024 trial start & end is not properly defined - trial_end = gray_on | punish_on | session_end + trial_end = 1 << 9 # 512 # positional events black = 1 << 10 # 0 - 60 cm @@ -148,9 +144,9 @@ def _extract_action_labels(self, vr, vr_data): print(">> Mapping vr event times...") # >>>> gray >>>> - # get gray_on times, i.e., trial starts + # get timestamps of gray gray_idx = vr_data.world_index[in_gray].index - # grays + # get first grays grays = np.where(gray_idx.diff() != 1)[0] # find time for first frame of gray @@ -166,13 +162,13 @@ def _extract_action_labels(self, vr, vr_data): action_labels[gray_off, 1] += Events.gray_off # <<<< gray <<<< - # >>>> white >>>> - # get punish_on times + # >>>> punishment >>>> + # get timestamps of punishment punish_idx = vr_data[in_white].index - # punishes + # get first punishment punishes = np.where(punish_idx.diff() != 1)[0] - # find time for first frame of punish + # find time for first frame of punishment punish_on_t = punish_idx[punishes] # find their index in vr data punish_on = vr_data.index.get_indexer(punish_on_t) @@ -183,7 +179,22 @@ def _extract_action_labels(self, vr, vr_data): # find their index in vr data punish_off = vr_data.index.get_indexer(punish_off_t) action_labels[punish_off, 1] += Events.punish_off - # <<<< white <<<< + # <<<< punishment <<<< + + # >>>> trial ends >>>> + # trial ends right before punishment starts + action_labels[punish_on-1, 1] += Events.trial_end + + # for non punished trials, right before gray on is when trial 
ends, and + # the last frame of the session + pre_gray_on_idx = np.append(gray_on[1:] - 1, vr_data.shape[0] - 1) + pre_gray_on = vr_data.iloc[pre_gray_on_idx] + # drop punish_off times + no_punished_t = pre_gray_on.drop(punish_off_t).index + # get index of trial ends in non punished trials + no_punished_idx = vr_data.index.get_indexer(no_punished_t) + action_labels[no_punished_idx, 1] += Events.trial_end + # <<<< trial ends <<<< # >>>> light >>>> # get index of data in light tunnel @@ -203,6 +214,8 @@ def _extract_action_labels(self, vr, vr_data): trial_starts = light_on[np.where( vr_data.iloc[light_on].position_in_tunnel % start_interval == 0 )[0]] + # label trial starts + action_labels[trial_starts, 1] += Events.trial_start if not trial_starts.size == vr_data.trial_count.max(): raise PixelsError(f"Number of trials does not equal to\ @@ -238,11 +251,6 @@ def _extract_action_labels(self, vr, vr_data): action_labels[dark_off, 1] += Events.dark_off # <<<< dark <<<< - # >>>> session ends >>>> - session_end = vr_data.shape[0] - 1 - action_labels[session_end, 1] += Events.session_end - # <<<< session ends <<<< - # >>>> licks >>>> licked_idx = np.where(vr_data.lick_count == 1)[0] action_labels[licked_idx, 1] += Events.licked @@ -301,7 +309,7 @@ def _extract_action_labels(self, vr, vr_data): action_labels[valve_open_idx, 1] += Events.valve_open # map valve closed valve_closed_idx = vr_data.index.get_indexer( - [reward_typed.index[-1]+1] + [reward_typed.index[-1]] ) action_labels[valve_closed_idx, 1] += Events.valve_closed # <<<< non aborted, valve only <<<< From db3f9b0b6f54b2e4d28464baaa18c24c9f9160d0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 4 Jul 2024 19:06:31 +0100 Subject: [PATCH 064/658] add todo --- pixels/behaviours/virtual_reality.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index b7cb1e5..aeb3215 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -36,6 +36,8 @@ class ActionLabels: """ # TODO jun 7 2024 does the name "action" make sense? + # TODO jul 4 2024 only label trial type at the first frame of the trial to + # make it easier for alignment??? 
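# >>>> first-frame labelling sketch >>>>
# Illustration only (toy indices, not part of the diff): one way to act on the TODO above
# is to keep the outcome flag only on a trial's first frame rather than on every frame;
# patch 067 ("add start index") starts down this road with
# start_idx = trial_idx[np.isin(trial_idx, trial_starts)].
import numpy as np
trial_idx = np.arange(40, 60)                           # all frames of one trial
trial_starts = np.array([0, 40, 90])                    # first frame of every trial
start_idx = trial_idx[np.isin(trial_idx, trial_starts)]
print(start_idx)                                        # -> [40]
# <<<< first-frame labelling sketch <<<<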
# triggered vr trials miss_light = 1 << 0 # 1 miss_dark = 1 << 1 # 2 From 2d966079bf372b1f711f3319ee50915faa11ac66 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 5 Jul 2024 18:17:52 +0100 Subject: [PATCH 065/658] add 1d convolution method since filter 1d does not work --- pixels/signal.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/pixels/signal.py b/pixels/signal.py index 65078d6..dad49e3 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -14,7 +14,8 @@ import matplotlib.pyplot as plt import pandas as pd import scipy.signal -from scipy.ndimage import gaussian_filter1d +from scipy.signal.windows import gaussian +from scipy.ndimage import gaussian_filter1d, convolve1d from pixels import ioutils, PixelsError @@ -267,6 +268,20 @@ def median_subtraction(data, axis=0): return data - np.median(data, axis=axis, keepdims=True) +def convolve_1d(times, sigma): + kernel_size = int(sigma * 6) + kernel = gaussian(kernel_size, std=sigma) + assert 0 + # TODO jul 5 2024 check computational neuroscience course work about + # spike rate convolution + n_kernel = kernel / np.sum(kernel) + convolved_df = times.apply(lambda x: convolve1d(x, kernel, mode='reflect')) + + convolved = gaussian_filter1d(times, sigma, axis=0) * 1000 + + return df + + def convolve(times, duration, sigma=None): """ Create a continuous signal from a set of spike times in milliseconds and convolve @@ -289,10 +304,10 @@ def convolve(times, duration, sigma=None): sigma = 50 # turn into array of 0s and 1s - times_arr = np.zeros((int(duration), len(times.columns))) + times_arr = np.zeros((duration.astype(int), len(times.columns))) for i, unit in enumerate(times): u_times = times[unit] + duration / 2 - u_times = u_times[~np.isnan(u_times)].astype(np.int) + u_times = u_times[~np.isnan(u_times)].astype(int) try: times_arr[u_times, i] = 1 except IndexError: From 409ee8da5d39acb1f43c4c5f0eabe28219256918 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 5 Jul 2024 18:19:56 +0100 Subject: [PATCH 066/658] add func to get aligned spike rate for each trial; add todo --- pixels/behaviours/base.py | 207 +++++++++++++++++++++++++++++++++----- 1 file changed, 181 insertions(+), 26 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 293dae0..3b1cced 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1560,6 +1560,8 @@ def get_action_labels(self): Returns the action labels, either from self._action_labels if they have been loaded already, or from file. """ + # TODO jul 5 2024: only one action label for a session, make sure it + # does not error return self._get_processed_data("_action_labels", "action_labels") def get_behavioural_data(self): @@ -1669,8 +1671,9 @@ def _get_spike_times(self, remapped=False): return spike_times[0] - def _get_aligned_spike_times(self, label, event, duration, rate=False, - sigma=None, units=None, end_event=None): + def _get_aligned_spike_times( + self, label, event, duration, rate=False, sigma=None, units=None + ): """ Returns spike times for each unit within a given time window around an event. 
align_trials delegates to this function, and should be used for getting aligned @@ -1687,38 +1690,21 @@ def _get_aligned_spike_times(self, label, event, duration, rate=False, spikes = self._get_spike_times()[units] # Convert to ms (self.sample_rate) spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate - # get index of spike times in data in sample_rate Hz too - spikes_idx = spikes * self.sample_rate - if rate & (type(duration) != str): + if rate: # pad ends with 1 second extra to remove edge effects from convolution duration += 2 - half = int((self.sample_rate * duration) / 2) - scan_duration = self.sample_rate * 8 + scan_duration = self.sample_rate * 8 + half = int((self.sample_rate * duration) / 2) cursor = 0 # In sample points i = -1 rec_trials = {} - # since each session has one behaviour session, now only one action - # label file - actions = action_labels[:, 0] - events = action_labels[:, 1] - # select trials - selected_trials = np.where(np.bitwise_and(actions, label))[0] - # map starts by event - starts = np.where(np.bitwise_and(events, event))[0] - assert 0 - # TODO jul 3 2024 continue here!!!! - ends = np.where(np.bitwise_and(events, end_event))[0] - - for rec_num in range(len(self.files)): - # TODO jun 12 2024 skip other streams for now - if rec_num > 0: - continue - #actions = action_labels[rec_num][:, 0] - #events = action_labels[rec_num][:, 1] + actions = action_labels[rec_num][:, 0] + events = action_labels[rec_num][:, 1] + trial_starts = np.where(np.bitwise_and(actions, label))[0] # Account for multiple raw data files meta = self.spike_meta[rec_num] @@ -1785,6 +1771,164 @@ def _get_aligned_spike_times(self, label, event, duration, rate=False, return trials + + def _get_aligned_trials( + self, label, event, end_event=None, sigma=None, bin_size=None, + units=None, + ): + """ + Returns spike times for each unit within a given time window around an event. + align_trials delegates to this function, and should be used for getting aligned + data in scripts. + """ + action_labels = self.get_action_labels() + + if units is None: + units = self.select_units() + + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure old code does not break! 
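# >>>> rate-from-spikes sketch >>>>
# Illustration only (not part of the diff): signal.convolve_1d is still a stub (it ends
# in an assert 0), so this is a hedged sketch of the step this function needs below —
# smoothing a binary spike array (rows in ms, one column per unit) with a unit-area
# Gaussian and scaling to spikes/s; kernel size and reflect mode follow the stub.
import numpy as np
from scipy.signal.windows import gaussian
from scipy.ndimage import convolve1d

def rate_from_binary(spiked, sigma=50, fs=1000):
    kernel = gaussian(int(sigma * 6), std=sigma)
    kernel /= kernel.sum()                              # unit area preserves spike counts
    return convolve1d(spiked, kernel, axis=0, mode="reflect") * fs   # spikes per second

demo = np.zeros((2000, 3))
demo[1000, 0] = 1                                       # one spike in unit 0 at t = 1 s
print(rate_from_binary(demo)[:, 0].max())               # peak of the smoothed single spike
# <<<< rate-from-spikes sketch <<<<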
+ #TODO: spike times cannot be indexed by unit ids anymore + spikes = self._get_spike_times()[units] + # Convert to ms (self.sample_rate) + spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate + # get index of spike times in data in sample_rate Hz too + spikes_idx = spikes * self.sample_rate + + # since each session has one behaviour session, now only one action + # label file + actions = action_labels[:, 0] + events = action_labels[:, 1] + # get timestamps index of behaviour in pixels stream (ms) + timestamps = action_labels[:, 2] + + # select frames of wanted trial type + #starts = np.where(np.bitwise_and(actions, label))[0] + trials = np.where(np.bitwise_and(actions, label))[0] + # map starts by event + starts = np.where(np.bitwise_and(events, event))[0] + # map starts by end event + ends = np.where(np.bitwise_and(events, end_event))[0] + + # only take starts from selected trials + selected_starts = np.where(np.isin(trials, starts))[0] + start_t = timestamps[selected_starts] + # only take ends from selected trials + selected_ends = np.where(np.isin(trials, ends))[0] + end_t = timestamps[selected_ends] + + # pad ends with 1 second extra to remove edge effects from convolution + scan_pad = self.sample_rate + scan_starts = start_t - scan_pad + scan_ends = end_t + scan_pad + scan_durations = scan_ends - scan_starts + # TODO jul 5 2024 CONTINUE HERE + + #cursor = 0 # In sample points + #i = -1 + rec_trials = {} + + for rec_num in range(len(self.files)): + # TODO jun 12 2024 skip other streams for now + if rec_num > 0: + continue + + # Account for multiple raw data files + meta = self.spike_meta[rec_num] + samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 + assert samples.is_integer() + milliseconds = samples / 30 + cursor_duration = cursor / 30 + rec_spikes = spikes[ + (cursor_duration <= spikes) & (spikes < (cursor_duration + milliseconds)) + ] - cursor_duration + cursor += samples + + # Account for lag, in case the ephys recording was started before the + # behaviour + if not self._lag[rec_num] == None: + lag_start, _ = self._lag[rec_num] + else: + lag_start = action_labels[0, -1] + + if lag_start < 0: + rec_spikes = rec_spikes + lag_start + + for i, start in enumerate(selected_starts): + # select spike times of current trial + trial_bool = (rec_spikes >= scan_starts[i])\ + & (rec_spikes <= scan_ends[i]) + + trial = rec_spikes[trial_bool] + + tdf = [] + + # initiate binary spike times array for current trial + times = np.zeros((scan_durations[i], len(units))).astype(int) + # use pixels time as spike index + idx = np.arange(scan_starts[i], scan_ends[i]) + # make it df, column name being unit id + spiked = pd.DataFrame(times, index=idx, columns=units) + + for unit in trial: + # get spike time for unit + u_times = trial[unit].values + # drop nan + u_times = u_times[~np.isnan(u_times)] + + # round spike times to use it as index + u_spike_idx = np.round(u_times).astype(int) + # make sure it does not exceed scan duration + if (u_spike_idx >= scan_ends[i]).any(): + beyonds = np.where(u_spike_idx >= scan_ends[i])[0] + u_spike_idx[beyonds] = idx[-1] + # make sure no double counted + u_spike_idx = np.unique(u_spike_idx) + + # set spiked to 1 + spiked.loc[u_spike_idx, unit] = 1 + + #udf = pd.DataFrame({int(unit): u_times}) + #tdf.append(udf) + + rates = signal.convolve_1d( + times=spiked, + sigma=sigma, + ) + assert 0 + # remove 1s padding from the start and end + rates = tdfc.iloc[scan_pad:\ + -scan_pad].reset_index(inplace=True) + # convert index to datetime index 
for resampling + rates.index = pd.to_timedelta(rates.index, unit='ms') + # resample to 100ms bin + bin_rates = rates.resample(bin_size).mean() + # use numeric index + bin_rates.index = np.arange(0, len(bin_rates)) + + rec_trials[i] = bin_rates + + if not rec_trials: + return None + + trials = pd.concat(rec_trials, axis=1, names=["trial", "unit"]) + trials = trials.reorder_levels(["unit", "trial"], axis=1) + trials = trials.sort_index(level=0, axis=1) + + assert 0 + # Set index to seconds and remove the padding 1 sec at each end + points = trials.shape[0] + scan_pad + start = (- duration / 2) + (duration / points) + # Having trouble with float values + #timepoints = np.linspace(start, duration / 2, points, dtype=np.float64) + timepoints = list(range(round(start * 1000), int(duration * 1000 / 2) + 1)) + trials['time'] = pd.Series(timepoints, index=trials.index) / 1000 + trials = trials.set_index('time') + trials = trials.iloc[self.sample_rate : - self.sample_rate] + + return trials + def select_units( self, group='good', min_depth=0, max_depth=None, min_spike_width=None, max_spike_width=None, uncurated=False, name=None @@ -1919,6 +2063,7 @@ def get_lfp_data_raw(self): def align_trials( self, label, event, data='spike_times', raw=False, duration=1, sigma=None, units=None, dlc_project=None, video_match=None, end_event=None, + bin_size=None, ): """ Get trials aligned to an event. This finds all instances of label in the action @@ -1972,6 +2117,8 @@ def align_trials( 'lfp', # Raw/downsampled channels from probe (LFP) 'motion_index', # Motion index per ROI from the video 'motion_tracking', # Motion tracking coordinates from DLC + 'trial_rate', # Taking spike times from the whole duration of each + # trial, convolve into spike rate ] if data not in data_options: raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") @@ -1981,7 +2128,15 @@ def align_trials( # we let a dedicated function handle aligning spike times return self._get_aligned_spike_times( label, event, duration, rate=data == "spike_rate", sigma=sigma, - units=units, end_event=end_event, + units=units, + ) + + if data == "trial_rate": + print(f"Aligning {data} to trials.") + # we let a dedicated function handle aligning spike times + return self._get_aligned_trials( + label, event, end_event=end_event, sigma=sigma, + bin_size=bin_size, units=units, ) if data == "motion_tracking" and not dlc_project: From a4b8c43f45c0da0ec95e0a716a8d9e353a448275 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 5 Jul 2024 18:20:27 +0100 Subject: [PATCH 067/658] add start index --- pixels/behaviours/virtual_reality.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index aeb3215..78a3c20 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -271,6 +271,8 @@ def _extract_action_labels(self, vr, vr_data): of_trial = (vr_data.trial_count == trial) # get index of current trial trial_idx = np.where(of_trial)[0] + # get start index of current trial + start_idx = trial_idx[np.isin(trial_idx, trial_starts)] # find where is non-zero reward type in current trial reward_typed = vr_data[of_trial & reward_not_none] # get trial type of current trial @@ -284,6 +286,7 @@ def _extract_action_labels(self, vr, vr_data): # punished outcome outcome = f"punished_{trial_type_str}" action_labels[trial_idx, 0] = getattr(ActionLabels, outcome) + #action_labels[start_idx, 0] = getattr(ActionLabels, outcome) # <<<< punished <<<< else: # 
>>>> non punished >>>> @@ -301,6 +304,7 @@ def _extract_action_labels(self, vr, vr_data): outcome = _outcome_map[reward_type] # label outcome action_labels[trial_idx, 0] = getattr(ActionLabels, outcome) + #action_labels[start_idx, 0] = getattr(ActionLabels, outcome) # <<<< non punished <<<< # >>>> non aborted, valve only >>>> From f6b163c002f4b31a9bdcdd1a118cc58ffb9655df Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 9 Jul 2024 17:10:38 +0100 Subject: [PATCH 068/658] make sure not the binned data saved to h5 --- pixels/behaviours/base.py | 73 +++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 3b1cced..042282f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1783,12 +1783,15 @@ def _get_aligned_trials( """ action_labels = self.get_action_labels() + # define output path + output_path = self.interim/\ + f'cache/{self.name}_{label}_{units}_fr_for_AL.npy' + if units is None: units = self.select_units() #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! - #TODO: spike times cannot be indexed by unit ids anymore spikes = self._get_spike_times()[units] # Convert to ms (self.sample_rate) spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate @@ -1822,11 +1825,10 @@ def _get_aligned_trials( scan_starts = start_t - scan_pad scan_ends = end_t + scan_pad scan_durations = scan_ends - scan_starts - # TODO jul 5 2024 CONTINUE HERE - #cursor = 0 # In sample points - #i = -1 + cursor = 0 # In sample points rec_trials = {} + bin_trials = {} for rec_num in range(len(self.files)): # TODO jun 12 2024 skip other streams for now @@ -1861,10 +1863,10 @@ def _get_aligned_trials( trial = rec_spikes[trial_bool] - tdf = [] - # initiate binary spike times array for current trial - times = np.zeros((scan_durations[i], len(units))).astype(int) + # NOTE: dtype must be float otherwise would get all 0 when + # passing gaussian kernel + times = np.zeros((scan_durations[i], len(units))).astype(float) # use pixels time as spike index idx = np.arange(scan_starts[i], scan_ends[i]) # make it df, column name being unit id @@ -1888,25 +1890,50 @@ def _get_aligned_trials( # set spiked to 1 spiked.loc[u_spike_idx, unit] = 1 - #udf = pd.DataFrame({int(unit): u_times}) - #tdf.append(udf) - - rates = signal.convolve_1d( + # convolve spike trains into spike rates + rates = signal.convolve_spike_trains( times=spiked, sigma=sigma, ) - assert 0 + # keep the original index + rates.index = spiked.index # remove 1s padding from the start and end - rates = tdfc.iloc[scan_pad:\ - -scan_pad].reset_index(inplace=True) + rates = rates.iloc[scan_pad: -scan_pad] + # reset index to zero at the beginning of the trial + #rates.reset_index(inplace=True, drop=True) + rec_trials[i] = rates + + bin_rates = rates.copy() + # reset index to zero at the beginning of the trial + bin_rates.reset_index(inplace=True, drop=True) # convert index to datetime index for resampling - rates.index = pd.to_timedelta(rates.index, unit='ms') + bin_rates.index = pd.to_timedelta(bin_rates.index, unit='ms') # resample to 100ms bin - bin_rates = rates.resample(bin_size).mean() + bin_rates = bin_rates.resample(bin_size).mean() # use numeric index bin_rates.index = np.arange(0, len(bin_rates)) + bin_trials[i] = bin_rates - rec_trials[i] = bin_rates + # align all trials by index + all_indices = list(set().union( + *[df.index for df in bin_trials.values()]) + ) + + # 
reindex all trials by the longest trial + dfs = {key: df.reindex(index=all_indices) + for key, df in bin_trials.items()} + + # get output + output = np.stack( + [df.values for df in dfs.values()], + axis=2, + ).T # reshape into trials x units x bins + + assert 0 + # TODO july 9 2024 how to get trials x temporal bin?? + # save output, for alfredo + np.save(output_path, output) + print(f"> Output saved at {output_path}.") if not rec_trials: return None @@ -1915,18 +1942,6 @@ def _get_aligned_trials( trials = trials.reorder_levels(["unit", "trial"], axis=1) trials = trials.sort_index(level=0, axis=1) - assert 0 - # Set index to seconds and remove the padding 1 sec at each end - points = trials.shape[0] - scan_pad - start = (- duration / 2) + (duration / points) - # Having trouble with float values - #timepoints = np.linspace(start, duration / 2, points, dtype=np.float64) - timepoints = list(range(round(start * 1000), int(duration * 1000 / 2) + 1)) - trials['time'] = pd.Series(timepoints, index=trials.index) / 1000 - trials = trials.set_index('time') - trials = trials.iloc[self.sample_rate : - self.sample_rate] - return trials def select_units( From 1a0a6101c36167afdee17d4591cb528cef84d294 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 9 Jul 2024 17:11:29 +0100 Subject: [PATCH 069/658] explicitly define gaussian kernel for spike rate convolution --- pixels/signal.py | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/pixels/signal.py b/pixels/signal.py index dad49e3..3ca0088 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -268,16 +268,42 @@ def median_subtraction(data, axis=0): return data - np.median(data, axis=axis, keepdims=True) -def convolve_1d(times, sigma): - kernel_size = int(sigma * 6) +def convolve_spike_trains(times, sigma=100, size=10): + """ + Convolve spike times data with 1D gaussian kernel to get spike rate. + + Parameters + ------- + times : pandas.DataFrame, time x units + Spike bool of units at each time point of a trial. + Dtype needs to be float, otherwise convolved results will be all 0. + + sigma : float/int, optional + Time in milliseconds of sigma of gaussian kernel to use. + Default: 100 ms. + + size : float/int, optional + Number of sigma for gaussian kernel to cover, i.e., size of the kernel + Default: 10. 
+ + """ + # get kernel size in ms + kernel_size = int(sigma * size) + # get gaussian kernel kernel = gaussian(kernel_size, std=sigma) - assert 0 - # TODO jul 5 2024 check computational neuroscience course work about - # spike rate convolution + # normalise kernel to ensure that the total area under the Gaussian is 1 n_kernel = kernel / np.sum(kernel) - convolved_df = times.apply(lambda x: convolve1d(x, kernel, mode='reflect')) - convolved = gaussian_filter1d(times, sigma, axis=0) * 1000 + # convolve with gaussian + convolved = convolve1d( + input=times.values, + weights=n_kernel, + output=float, + mode='nearest', + axis=0, + ) * 1000 # rescale it to second + + df = pd.DataFrame(convolved, columns=times.columns) return df From 27d800f1fa3cca00d7fb54782b2dbf5f925ee576 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 11 Jul 2024 21:30:47 +0100 Subject: [PATCH 070/658] add position bin to trial alignment; explicitly state bin to be time or position --- pixels/behaviours/base.py | 45 ++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 042282f..8e8d905 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1773,8 +1773,8 @@ def _get_aligned_spike_times( def _get_aligned_trials( - self, label, event, end_event=None, sigma=None, bin_size=None, - units=None, + self, label, event, end_event=None, sigma=None, time_bin=None, + pos_bin=None, units=None, ): """ Returns spike times for each unit within a given time window around an event. @@ -1790,6 +1790,13 @@ def _get_aligned_trials( if units is None: units = self.select_units() + if not pos_bin is None: + vr_dir = self.find_file(self.files[0]['vr']) + with open(vr_dir, 'rb') as f: + vr_data = pickle.load(f) + # get positions + positions = vr_data.position_in_tunnel + #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! 
spikes = self._get_spike_times()[units] @@ -1860,9 +1867,13 @@ def _get_aligned_trials( # select spike times of current trial trial_bool = (rec_spikes >= scan_starts[i])\ & (rec_spikes <= scan_ends[i]) - trial = rec_spikes[trial_bool] + # get position bin ids for current trial + trial_pos_bool = (positions.index >= start_t[i])\ + & (positions.index <= end_t[i]) + trial_bin_pos = positions[trial_pos_bool] + # initiate binary spike times array for current trial # NOTE: dtype must be float otherwise would get all 0 when # passing gaussian kernel @@ -1903,16 +1914,24 @@ def _get_aligned_trials( #rates.reset_index(inplace=True, drop=True) rec_trials[i] = rates - bin_rates = rates.copy() + bin_trial = rates.copy() + # add position here to bin together + bin_trial['positions'] = positions + # reset index to zero at the beginning of the trial - bin_rates.reset_index(inplace=True, drop=True) + bin_trial.reset_index(inplace=True, drop=True) # convert index to datetime index for resampling - bin_rates.index = pd.to_timedelta(bin_rates.index, unit='ms') + bin_trial.index = pd.to_timedelta(bin_trial.index, unit='ms') # resample to 100ms bin - bin_rates = bin_rates.resample(bin_size).mean() + bin_trial = bin_trial.resample(time_bin).mean() # use numeric index - bin_rates.index = np.arange(0, len(bin_rates)) - bin_trials[i] = bin_rates + bin_trial.index = np.arange(0, len(bin_trial)) + # bin positions and only save the bin index + # NOTE: here position bin index starts at 1, for alfredo + # to make it back to 0-indexing, remove +1 at the end + bin_trial['positions'] = bin_trial['positions'] // pos_bin + 1 + + bin_trials[i] = bin_trial # align all trials by index all_indices = list(set().union( @@ -1929,8 +1948,6 @@ def _get_aligned_trials( axis=2, ).T # reshape into trials x units x bins - assert 0 - # TODO july 9 2024 how to get trials x temporal bin?? # save output, for alfredo np.save(output_path, output) print(f"> Output saved at {output_path}.") @@ -1938,6 +1955,8 @@ def _get_aligned_trials( if not rec_trials: return None + # TODO july 10 2024 shuffle spike times for each unit across + trials = pd.concat(rec_trials, axis=1, names=["trial", "unit"]) trials = trials.reorder_levels(["unit", "trial"], axis=1) trials = trials.sort_index(level=0, axis=1) @@ -2078,7 +2097,7 @@ def get_lfp_data_raw(self): def align_trials( self, label, event, data='spike_times', raw=False, duration=1, sigma=None, units=None, dlc_project=None, video_match=None, end_event=None, - bin_size=None, + time_bin=None, pos_bin=False, ): """ Get trials aligned to an event. This finds all instances of label in the action @@ -2151,7 +2170,7 @@ def align_trials( # we let a dedicated function handle aligning spike times return self._get_aligned_trials( label, event, end_event=end_event, sigma=sigma, - bin_size=bin_size, units=units, + time_bin=time_bin, pos_bin=pos_bin, units=units, ) if data == "motion_tracking" and not dlc_project: From 1adcbbfea8bd2fec99258d39cfc7fbbe8ec60093 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Jul 2024 20:02:59 +0100 Subject: [PATCH 071/658] add documentation --- pixels/behaviours/base.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 8e8d905..7535e47 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1777,9 +1777,13 @@ def _get_aligned_trials( pos_bin=None, units=None, ): """ - Returns spike times for each unit within a given time window around an event. 
+ Returns spike rate for each unit within a trial. align_trials delegates to this function, and should be used for getting aligned data in scripts. + + This function also saves binned data in the format that Alfredo wants: + trials * units * temporal bins (100ms) + """ action_labels = self.get_action_labels() @@ -2140,6 +2144,15 @@ def align_trials( When aligning video or motion index data, use this fnmatch pattern to select videos. + end_event : int | None + For VR behaviour, when aligning to the whole trial, this param is + the end event to align to. + + time_bin: str | None + For VR behaviour, size of temporal bin for spike rate data. + + pos_bin: int | None + For VR behaviour, size of positional bin for position data. """ data = data.lower() From fa1fc1fa8081337a7069bb95b0bd49b64d0fac5c Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 30 Sep 2024 18:20:28 +0100 Subject: [PATCH 072/658] add refactored code that needs testing --- pixels/behaviours/virtual_reality.py | 101 +++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 78a3c20..ec9b05b 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -344,3 +344,104 @@ def _check_action_labels(self, vr_data, action_labels, plot=True): plt.show() return action_labels + + ''' + # TODO sep 30 2024: + # refactored code from chatgpt + # needs testing! + def _assign_event_label(self, action_labels, event_times, event_type, column=1): + """ + Helper function to assign event labels to action_labels array. + """ + event_indices = event_times.index + event_on_idx = np.where(event_indices.diff() != 1)[0] + + # Find first and last timepoints for events + event_on_t = event_indices[event_on_idx] + event_off_t = np.append(event_indices[event_on_idx[1:] - 1], event_indices[-1]) + + event_on = event_times.index.get_indexer(event_on_t) + event_off = event_times.index.get_indexer(event_off_t) + + action_labels[event_on, column] += event_type['on'] + action_labels[event_off, column] += event_type['off'] + + return action_labels + + def _map_trial_events(self, action_labels, vr_data, vr): + """ + Maps different trial events like gray, light, dark, and punishments. + """ + # Define event mappings for gray, light, dark, punishments + event_mappings = { + 'gray': {'on': Events.gray_on, 'off': Events.gray_off, + 'condition': vr_data.world_index == World.GRAY}, + 'light': {'on': Events.light_on, 'off': Events.light_off, + 'condition': vr_data.world_index == World.TUNNEL}, + 'dark': {'on': Events.dark_on, 'off': Events.dark_off, + 'condition': (vr_data.world_index == World.DARK_5)\ + | (vr_data.world_index == World.DARK_2_5)\ + | (vr_data.world_index == World.DARK_FULL)}, + 'punish': {'on': Events.punish_on, 'off': Events.punish_off, + 'condition': vr_data.world_index == World.WHITE}, + } + + for event_name, event_type in event_mappings.items(): + event_times = vr_data[event_type['condition']] + action_labels = self._assign_event_label(action_labels, event_times, event_type) + + return action_labels + + def _assign_trial_outcomes(self, action_labels, vr_data, vr): + """ + Assign outcomes for each trial, including rewards and punishments. 
+ """ + for t, trial in enumerate(vr_data.trial_count.unique()): + # Extract trial-specific information + of_trial = (vr_data.trial_count == trial) + trial_idx = np.where(of_trial)[0] + + reward_not_none = (vr_data.reward_type != Outcome.NONE) + reward_typed = vr_data[of_trial & reward_not_none] + trial_type = int(vr_data[of_trial].trial_type.unique()) + trial_type_str = trial_type_lookup.get(trial_type).lower() + + if reward_typed.size == 0\ + and vr_data[of_trial\ + & (vr_data.world_index == World.WHITE)].size != 0: + # Handle punishment case + outcome = f"punished_{trial_type_str}" + else: + reward_type = int(reward_typed.reward_type.unique()) + outcome = _outcome_map.get(reward_type, "unknown") + + if reward_type == Outcome.TRIGGERED: + outcome = f"{outcome}_{trial_type_str}" + + action_labels[trial_idx, 0] = getattr(ActionLabels, outcome, 0) + + if reward_type > Outcome.NONE: + valve_open_idx = vr_data.index.get_indexer([reward_typed.index[0]]) + valve_closed_idx = vr_data.index.get_indexer([reward_typed.index[-1]]) + action_labels[valve_open_idx, 1] += Events.valve_open + action_labels[valve_closed_idx, 1] += Events.valve_closed + + return action_labels + + def _extract_action_labels(self, vr, vr_data): + """ + Extract action labels from VR data and assign events and outcomes. + """ + action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.int32) + + # Map events + action_labels = self._map_trial_events(action_labels, vr_data, vr) + + # Assign trial outcomes + action_labels = self._assign_trial_outcomes(action_labels, vr_data, vr) + + # Add timestamps to action labels + action_labels = np.column_stack((action_labels, vr_data.index.values)) + + return action_labels + ''' From 8e79af76076b46023a6413b5c3ba6fd85908c67f Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 12:46:06 +0000 Subject: [PATCH 073/658] change class names --- pixels/behaviours/virtual_reality.py | 60 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index ec9b05b..a37c8ac 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt import pandas as pd -from vision_in_darkness.session import Outcome, World, Trial_Type +from vision_in_darkness.base import Outcomes, Worlds, Conditions from pixels import Experiment, PixelsError from pixels import signal, ioutils @@ -106,18 +106,18 @@ class Events: # map trial outcome _outcome_map = { - Outcome.ABORTED_DARK: "miss_dark", - Outcome.ABORTED_LIGHT: "miss_light", - Outcome.TRIGGERED: "triggered", - Outcome.AUTO_LIGHT: "auto_light", - Outcome.DEFAULT: "default_light", - Outcome.REINF_LIGHT: "reinf_light", - Outcome.AUTO_DARK: "auto_dark", - Outcome.REINF_DARK: "reinf_dark", + Outcomes.ABORTED_DARK: "miss_dark", + Outcomes.ABORTED_LIGHT: "miss_light", + Outcomes.TRIGGERED: "triggered", + Outcomes.AUTO_LIGHT: "auto_light", + Outcomes.DEFAULT: "default_light", + Outcomes.REINF_LIGHT: "reinf_light", + Outcomes.AUTO_DARK: "auto_dark", + Outcomes.REINF_DARK: "reinf_dark", } # function to look up trial type -trial_type_lookup = {v: k for k, v in vars(Trial_Type).items()} +trial_type_lookup = {v: k for k, v in vars(Conditions).items()} class VR(Behaviour): @@ -127,20 +127,20 @@ def _extract_action_labels(self, vr, vr_data): # >>>> definitions >>>> # define in gray - in_gray = (vr_data.world_index == World.GRAY) + in_gray = (vr_data.world_index == Worlds.GRAY) # define in 
dark - in_dark = (vr_data.world_index == World.DARK_5)\ - | (vr_data.world_index == World.DARK_2_5)\ - | (vr_data.world_index == World.DARK_FULL) + in_dark = (vr_data.world_index == Worlds.DARK_5)\ + | (vr_data.world_index == Worlds.DARK_2_5)\ + | (vr_data.world_index == Worlds.DARK_FULL) # define in white - in_white = (vr_data.world_index == World.WHITE) + in_white = (vr_data.world_index == Worlds.WHITE) # define in tunnel in_tunnel = ~in_gray & ~in_white # define in light - in_light = (vr_data.world_index == World.TUNNEL) + in_light = (vr_data.world_index == Worlds.TUNNEL) # define light & dark trials - trial_light = (vr_data.trial_type == Trial_Type.LIGHT) - trial_dark = (vr_data.trial_type == Trial_Type.DARK) + trial_light = (vr_data.trial_type == Conditions.LIGHT) + trial_dark = (vr_data.trial_type == Conditions.DARK) # <<<< definitions <<<< print(">> Mapping vr event times...") @@ -297,7 +297,7 @@ def _extract_action_labels(self, vr, vr_data): """ triggered """ # catch triggered trials and separate trial types - if reward_type == Outcome.TRIGGERED: + if reward_type == Outcomes.TRIGGERED: outcome = f"{_outcome_map[reward_type]}_{trial_type_str}" else: """ given & aborted """ @@ -309,7 +309,7 @@ def _extract_action_labels(self, vr, vr_data): # >>>> non aborted, valve only >>>> # if not aborted, map valve open & closed - if reward_type > Outcome.NONE: + if reward_type > Outcomes.NONE: # map valve open valve_open_idx = vr_data.index.get_indexer([reward_typed.index[0]]) action_labels[valve_open_idx, 1] += Events.valve_open @@ -375,15 +375,15 @@ def _map_trial_events(self, action_labels, vr_data, vr): # Define event mappings for gray, light, dark, punishments event_mappings = { 'gray': {'on': Events.gray_on, 'off': Events.gray_off, - 'condition': vr_data.world_index == World.GRAY}, + 'condition': vr_data.world_index == Worlds.GRAY}, 'light': {'on': Events.light_on, 'off': Events.light_off, - 'condition': vr_data.world_index == World.TUNNEL}, + 'condition': vr_data.world_index == Worlds.TUNNEL}, 'dark': {'on': Events.dark_on, 'off': Events.dark_off, - 'condition': (vr_data.world_index == World.DARK_5)\ - | (vr_data.world_index == World.DARK_2_5)\ - | (vr_data.world_index == World.DARK_FULL)}, + 'condition': (vr_data.world_index == Worlds.DARK_5)\ + | (vr_data.world_index == Worlds.DARK_2_5)\ + | (vr_data.world_index == Worlds.DARK_FULL)}, 'punish': {'on': Events.punish_on, 'off': Events.punish_off, - 'condition': vr_data.world_index == World.WHITE}, + 'condition': vr_data.world_index == Worlds.WHITE}, } for event_name, event_type in event_mappings.items(): @@ -401,26 +401,26 @@ def _assign_trial_outcomes(self, action_labels, vr_data, vr): of_trial = (vr_data.trial_count == trial) trial_idx = np.where(of_trial)[0] - reward_not_none = (vr_data.reward_type != Outcome.NONE) + reward_not_none = (vr_data.reward_type != Outcomes.NONE) reward_typed = vr_data[of_trial & reward_not_none] trial_type = int(vr_data[of_trial].trial_type.unique()) trial_type_str = trial_type_lookup.get(trial_type).lower() if reward_typed.size == 0\ and vr_data[of_trial\ - & (vr_data.world_index == World.WHITE)].size != 0: + & (vr_data.world_index == Worlds.WHITE)].size != 0: # Handle punishment case outcome = f"punished_{trial_type_str}" else: reward_type = int(reward_typed.reward_type.unique()) outcome = _outcome_map.get(reward_type, "unknown") - if reward_type == Outcome.TRIGGERED: + if reward_type == Outcomes.TRIGGERED: outcome = f"{outcome}_{trial_type_str}" action_labels[trial_idx, 0] = getattr(ActionLabels, outcome, 
0) - if reward_type > Outcome.NONE: + if reward_type > Outcomes.NONE: valve_open_idx = vr_data.index.get_indexer([reward_typed.index[0]]) valve_closed_idx = vr_data.index.get_indexer([reward_typed.index[-1]]) action_labels[valve_open_idx, 1] += Events.valve_open From 7fba8a1ba03a2ea95b203475e814994cb12c644e Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 12:47:46 +0000 Subject: [PATCH 074/658] add unfinished trials --- pixels/behaviours/virtual_reality.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index a37c8ac..101b36d 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -288,6 +288,17 @@ def _extract_action_labels(self, vr, vr_data): action_labels[trial_idx, 0] = getattr(ActionLabels, outcome) #action_labels[start_idx, 0] = getattr(ActionLabels, outcome) # <<<< punished <<<< + + elif (reward_typed.size == 0)\ + & (vr_data[of_trial & in_white].size == 0): + # >>>> unfinished trial >>>> + # double check it is the last trial + assert (trial == vr_data.trial_count.unique().max()) + assert (vr_data[of_trial].position_in_tunnel.max()\ + < vr.tunnel_reset) + print(f"> trial {trial} is unfinished when session ends, so " + "there is no outcome.") + # <<<< unfinished trial <<<< else: # >>>> non punished >>>> # get non-zero reward type in current trial From 82acc1f456a232a0d70338f17a346093b91eea44 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 12:51:21 +0000 Subject: [PATCH 075/658] get lick onset from lick_detect and align to that --- pixels/behaviours/virtual_reality.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 101b36d..f3906ea 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -254,17 +254,18 @@ def _extract_action_labels(self, vr, vr_data): # <<<< dark <<<< # >>>> licks >>>> - licked_idx = np.where(vr_data.lick_count == 1)[0] + lick_onsets = np.diff(vr_data.lick_detect, prepend=0) + licked_idx = np.where(lick_onsets == 1)[0] action_labels[licked_idx, 1] += Events.licked # <<<< licks <<<< - # TODO jun 27 positional events and valve events needs mapping + # TODO jun 27 2024 positional events and valve events needs mapping print(">> Mapping vr action times...") # >>>> map reward types >>>> # get non-zero reward types - reward_not_none = (vr_data.reward_type != Outcome.NONE) + reward_not_none = (vr_data.reward_type != Outcomes.NONE) for t, trial in enumerate(vr_data.trial_count.unique()): # get current trial From 4eb8f5344c5a2971e02be6bd1e33511e3aeebb13 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 12:52:16 +0000 Subject: [PATCH 076/658] correct indentation --- pixels/behaviours/virtual_reality.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index f3906ea..8a005ff 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -361,7 +361,7 @@ def _check_action_labels(self, vr_data, action_labels, plot=True): # TODO sep 30 2024: # refactored code from chatgpt # needs testing! - def _assign_event_label(self, action_labels, event_times, event_type, column=1): + def _assign_event_label(self, action_labels, event_times, event_type, column=1): """ Helper function to assign event labels to action_labels array. 
""" From 8de7902131e4ec53009986ef7689fe506b4f135b Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 12:53:02 +0000 Subject: [PATCH 077/658] add compression filter when saving hdf5 --- pixels/ioutils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 533a165..56d0d14 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -318,6 +318,8 @@ def write_hdf5(path, df): path_or_buf=path, key='df', mode='w', + complevel=9, + complib="blosc:lz4hc", ) print('HDF5 saved to ', path) From ac93bdd59509fb3d79fcedb706ee77b76e9298b4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 12:53:27 +0000 Subject: [PATCH 078/658] ignore lfp data for now, and save action_labels as npz --- pixels/ioutils.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 56d0d14..732af8b 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -78,8 +78,8 @@ def get_data_files(data_dir, session_name): recording = {} recording['spike_data'] = original_name(spike_recording) recording['spike_meta'] = original_name(spike_meta[num]) - recording['lfp_data'] = original_name(lfp_data[num]) - recording['lfp_meta'] = original_name(lfp_meta[num]) + #recording['lfp_data'] = original_name(lfp_data[num]) + #recording['lfp_meta'] = original_name(lfp_meta[num]) if behaviour: if len(behaviour) == len(spike_data): @@ -100,27 +100,27 @@ def get_data_files(data_dir, session_name): recording['behaviour'] = None recording['behaviour_processed'] = None - recording['action_labels'] = Path(f'action_labels_{num}.npy') + recording['action_labels'] = Path(f'action_labels_{num}.npz') recording['spike_processed'] = recording['spike_data'].with_name( recording['spike_data'].stem + '_processed.h5' ) recording['spike_rate_processed'] = Path(f'spike_rate_{num}.h5') - recording['lfp_processed'] = recording['lfp_data'].with_name( - recording['lfp_data'].stem + '_processed.npy' - ) - recording['lfp_sd'] = recording['lfp_data'].with_name( - recording['lfp_data'].stem + '_sd.json' - ) - recording['clustered_channels'] = recording['lfp_data'].with_name( - f'channel_clustering_results_{num}.h5' - ) + #recording['lfp_processed'] = recording['lfp_data'].with_name( + # recording['lfp_data'].stem + '_processed.npy' + #) + #recording['lfp_sd'] = recording['lfp_data'].with_name( + # recording['lfp_data'].stem + '_sd.json' + #) + #recording['clustered_channels'] = recording['lfp_data'].with_name( + # f'channel_clustering_results_{num}.h5' + #) recording['depth_info'] = recording['spike_data'].with_name( f'depth_info_{num}.json' ) recording['CatGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") recording['CatGT_ap_meta'] = str(recording['spike_meta']).replace("t0", "tcat") recording['vr'] = recording['spike_data'].with_name( - f'{session_name}_vr_synched.pickle' + f'{session_name}_vr_synched.h5' ) files.append(recording) From d3292cee742a5bc41ca53bf6a83f97c6a0a39a43 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 12:54:33 +0000 Subject: [PATCH 079/658] use decimate to downsample --- pixels/signal.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/pixels/signal.py b/pixels/signal.py index 3ca0088..faca2d5 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -20,6 +20,39 @@ from pixels import ioutils, PixelsError +def decimate(array, from_hz, to_hz, ftype="fir"): + """ + Downsample the signal after applying an anti-aliasing filter. 
+ Downsampling factor MUST be an integer, if not, call `resample`. + + Params + === + array : ndarray, Series or similar + The data to be resampled. + + from_hz : int or float + The starting frequency of the data. + + sample_rate : int or float, optional + The resulting sample rate. + + ftype: str or dlit instance + low pass filter type. + Default: fir (finite impulse response) filter + """ + if from_hz % to_hz == 0: + factor = from_hz // to_hz + output = scipy.signal.decimate( + x=array, + q=factor, + ftype=ftype, + ) + else: + output = resample(array, from_hz, to_hz) + + return output + + def resample(array, from_hz, to_hz, poly=True, padtype=None): """ Resample an array from one sampling rate to another. From 4da1ea157ae469e2bdb7bc28e714d2566d561686 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:23:20 +0000 Subject: [PATCH 080/658] increase sampling rate to 2kHz to safely capture all spikes --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7535e47..9e5e905 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -139,7 +139,7 @@ class Behaviour(ABC): """ - sample_rate = 1000 + SAMPLE_RATE = 2000#1000 def __init__(self, name, data_dir, metadata=None, processed_dir=None, interim_dir=None): From e2b31050a7f2fa08ea5ea61ce0829bb7d534897a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:23:51 +0000 Subject: [PATCH 081/658] ignore lfp for now --- pixels/behaviours/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 9e5e905..aa68e9c 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -201,9 +201,9 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self.spike_meta = [ ioutils.read_meta(self.find_file(f['spike_meta'], copy=False)) for f in self.files ] - self.lfp_meta = [ - ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files - ] + #self.lfp_meta = [ + # ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files + #] # environmental variable PIXELS_CACHE={0,1} can be {disable,enable} cache self.set_cache(bool(int(os.environ.get("PIXELS_CACHE", 1)))) From 6ccf30d5484a0c8eb7b0ee3735231cce2085b63d Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:24:41 +0000 Subject: [PATCH 082/658] use upper case to represent constant --- pixels/behaviours/base.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index aa68e9c..7ffed8c 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -358,7 +358,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): data_file = self.find_file(recording['behaviour']) behavioural_data = ioutils.read_tdms(data_file, groups=["NpxlSync_Signal"]) behavioural_data = signal.resample( - behavioural_data.values, BEHAVIOUR_HZ, self.sample_rate + behavioural_data.values, BEHAVIOUR_HZ, self.SAMPLE_RATE ) if sync_channel is None: @@ -368,7 +368,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): sync_channel = ioutils.read_bin(data_file, num_chans, channel=384) orig_rate = int(self.lfp_meta[rec_num]['imSampRate']) #sync_channel = sync_channel[:120 * orig_rate * 2] # 2 mins, rec Hz, back/forward - sync_channel = signal.resample(sync_channel, orig_rate, 
self.sample_rate) + sync_channel = signal.resample(sync_channel, orig_rate, self.SAMPLE_RATE) behavioural_data = signal.binarise(behavioural_data) sync_channel = signal.binarise(sync_channel) @@ -484,7 +484,7 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): # convert spike times to ms orig_rate = int(self.spike_meta[0]['imSampRate']) - times_ms = times * self.sample_rate / orig_rate + times_ms = times * self.SAMPLE_RATE / orig_rate lag = [None, 'later', 'earlier'] print(f"""\n> {stream_ids[0]} started\r @@ -514,7 +514,7 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): max shift {np.max(remapped_times_ms-times_ms):.2f}ms.""") # convert remappmed times back to its original sample index - remapped_times = np.uint64(remapped_times_ms * orig_rate / self.sample_rate) + remapped_times = np.uint64(remapped_times_ms * orig_rate / self.SAMPLE_RATE) np.save(output, remapped_times) print(f'\n> Spike times remapping output saved to\n {output}.') @@ -548,8 +548,8 @@ def process_behaviour(self): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - print(f"> Downsampling to {self.sample_rate} Hz") - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] @@ -599,8 +599,8 @@ def process_spikes(self): print("> Mapping spike data") data = ioutils.read_bin(data_file, num_chans) - print(f"> Downsampling to {self.sample_rate} Hz") - data = signal.resample(data, orig_rate, self.sample_rate) + print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + data = signal.resample(data, orig_rate, self.SAMPLE_RATE) # Ideally we would median subtract before downsampling, but that takes a # very long time and is at risk of memory errors, so let's do it after. 
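# --- illustrative sketch, not part of this patch ----------------------------
# The comment above defers median subtraction until after downsampling. This
# is a minimal, self-contained example of that ordering, assuming a small
# samples x channels float array and the helpers defined in pixels/signal.py
# (signal.resample and signal.median_subtraction); the array shape and rates
# below are made up for illustration only.
import numpy as np
from pixels import signal

fake_data = np.random.randn(30000, 4)                  # hypothetical samples x channels
downsampled = signal.resample(fake_data, 30000, 1000)  # downsample first (30 kHz -> 1 kHz)
cleaned = signal.median_subtraction(downsampled, axis=1)  # then subtract per-sample median across channels
# -----------------------------------------------------------------------------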
@@ -645,9 +645,9 @@ def process_lfp(self): #subtracted = signal.median_subtraction(data, axis=1) cmred = spre.common_reference(data) - print(f"> Downsampling to {self.sample_rate} Hz") - #downsampled = signal.resample(subtracted, orig_rate, self.sample_rate, False) - downsampled = spre.resample(cmred, self.sample_rate) + print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + #downsampled = signal.resample(subtracted, orig_rate, self.SAMPLE_RATE, False) + downsampled = spre.resample(cmred, self.SAMPLE_RATE) # get traces downsampled = downsampled.get_traces() @@ -1196,7 +1196,7 @@ def _align_dlc_coords(self, rec_num, metadata, coords): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] @@ -1386,7 +1386,7 @@ def process_motion_index(self, video_match): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] trigger = signal.binarise(behavioural_data["/'CamFrames'/'0'"]).values @@ -1456,14 +1456,14 @@ def add_motion_index_action_label( action_labels = self.get_action_labels() motion_indexes = self.get_motion_index_data() - scan_duration = self.sample_rate * 10 + scan_duration = self.SAMPLE_RATE * 10 half = scan_duration // 2 # We take 200 ms before the action begins as a short baseline period for each # trial. The smallest standard deviation of all SDs of these baseline periods is # used as a threshold to identify "clean" trials (`clean_threshold` below). 
# Non-clean trials are trials where TODO - short_pre = int(0.2 * self.sample_rate) + short_pre = int(0.2 * self.SAMPLE_RATE) for rec_num, recording in enumerate(self.files): # Only recs with camera_data will have motion indexes From d77e9e3636aca6f9fffa7b329386ff332d0a212e Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:24:58 +0000 Subject: [PATCH 083/658] allows .npz too --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7ffed8c..c534bdc 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1545,7 +1545,7 @@ def _get_processed_data(self, attr, key): if key in recording: file_path = self.processed / recording[key] if file_path.exists(): - if file_path.suffix == '.npy': + if re.search(r'\.np[yz]$', file_path.suffix): saved[rec_num] = np.load(file_path) elif file_path.suffix == '.h5': saved[rec_num] = ioutils.read_hdf5(file_path) From 8f78ea37d054fca3f83ed499fa11dc9fcbc27a4f Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:25:32 +0000 Subject: [PATCH 084/658] use spikeinterface to get spike times --- pixels/behaviours/base.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c534bdc..7ee6d14 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1624,6 +1624,37 @@ def get_lfp_data(self): """ return self._get_processed_data("_lfp_data", "lfp_processed") + + def _get_si_spike_times(self, sa_dir): + """ + get spike times in second with spikeinterface + """ + spike_times = self._spike_times_data + + for stream_num, stream in enumerate(range(len(spike_times))): + # load sorting analyser + sa = si.load_sorting_analyzer(sa_dir) + + times = {} + # get spike train + for i, unit_id in enumerate(sa.unit_ids): + unit_times = sa.sorting.get_unit_spike_train( + unit_id=unit_id, + return_times=False, + ) + times[unit_id] = pd.Series(unit_times) + # concatenate units + spike_times[stream_num] = pd.concat( + objs=times, + axis=1, + names="unit", + ) + # Convert to time into sample rate index + spike_times[stream_num] /= int(self.spike_meta[0]['imSampRate'])\ + / self.SAMPLE_RATE + + return spike_times[0] # NOTE: only deal with one stream for now + def _get_spike_times(self, remapped=False): """ Returns the sorted spike times. 
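Illustrative sketch (not part of the patch above): a standalone example that mirrors the spikeinterface calls used by the new `_get_si_spike_times`. The analyzer path, the 30 kHz acquisition rate, and the `import spikeinterface.full as si` alias are assumptions for illustration; only the call pattern itself comes from the patch.

import pandas as pd
import spikeinterface.full as si   # assumed alias for the `si` used in base.py

sa = si.load_sorting_analyzer("/path/to/curated_sa.zarr")   # hypothetical path

times = {}
for unit_id in sa.unit_ids:
    # spike times as raw sample indices, as in the patch (return_times=False)
    spike_train = sa.sorting.get_unit_spike_train(unit_id=unit_id, return_times=False)
    times[unit_id] = pd.Series(spike_train)

spike_times = pd.concat(times, axis=1, names=["unit"])
# convert raw sample indices to the analysis time base, as the patch does
spike_times /= 30000 / 2000   # assumes imSampRate of 30 kHz; SAMPLE_RATE is 2 kHz here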
From 1fa6bbfd9de3e8a182d63534e2f54d4ee23918ea Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:25:58 +0000 Subject: [PATCH 085/658] convert spike times into SAMPLE_RATE scale before returning it --- pixels/behaviours/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7ee6d14..7f1c782 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1699,6 +1699,9 @@ def _get_spike_times(self, remapped=False): by_clust[c] = pd.Series(uniques) spike_times[stream_num] = pd.concat(by_clust, axis=1, names=['unit']) + # Convert to time into sample rate index + spike_times[stream_num] /= int(self.spike_meta[0]['imSampRate'])\ + / self.SAMPLE_RATE return spike_times[0] From 076926fbcf652760e7cc29d9d15a0cd27dda31b0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:26:33 +0000 Subject: [PATCH 086/658] use upper case for constant; put scale convertion while getting spike times --- pixels/behaviours/base.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7f1c782..121a52b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1722,15 +1722,13 @@ def _get_aligned_spike_times( #make sure old code does not break! #TODO: spike times cannot be indexed by unit ids anymore spikes = self._get_spike_times()[units] - # Convert to ms (self.sample_rate) - spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate if rate: # pad ends with 1 second extra to remove edge effects from convolution duration += 2 - scan_duration = self.sample_rate * 8 - half = int((self.sample_rate * duration) / 2) + scan_duration = self.SAMPLE_RATE * 8 + half = int((self.SAMPLE_RATE * duration) / 2) cursor = 0 # In sample points i = -1 rec_trials = {} @@ -1801,7 +1799,7 @@ def _get_aligned_spike_times( timepoints = list(range(round(start * 1000), int(duration * 1000 / 2) + 1)) trials['time'] = pd.Series(timepoints, index=trials.index) / 1000 trials = trials.set_index('time') - trials = trials.iloc[self.sample_rate : - self.sample_rate] + trials = trials.iloc[self.SAMPLE_RATE : - self.SAMPLE_RATE] return trials From 51ddb817185d17b76ee8e967243efbae2c1044c7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:28:59 +0000 Subject: [PATCH 087/658] add temp stuff; access only the first action_label since there is only one; synched vr_data now is hdf5 --- pixels/behaviours/base.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 121a52b..3ac9d8b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1817,39 +1817,38 @@ def _get_aligned_trials( trials * units * temporal bins (100ms) """ - action_labels = self.get_action_labels() + action_labels = self.get_action_labels()[0] # define output path output_path = self.interim/\ - f'cache/{self.name}_{label}_{units}_fr_for_AL.npy' + f'cache/{self.name}_{label}_{units}_fr_for_AL.npz' if units is None: units = self.select_units() if not pos_bin is None: vr_dir = self.find_file(self.files[0]['vr']) - with open(vr_dir, 'rb') as f: - vr_data = pickle.load(f) + vr_data = ioutils.read_hdf5(vr_dir) # get positions positions = vr_data.position_in_tunnel #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! 
- spikes = self._get_spike_times()[units] - # Convert to ms (self.sample_rate) - spikes /= int(self.spike_meta[0]['imSampRate']) / self.sample_rate - # get index of spike times in data in sample_rate Hz too - spikes_idx = spikes * self.sample_rate + #spikes = self._get_spike_times()[units] + sa_dir="/home/amz/synology/arthur/data/npx/interim/20240812_az_VDCN09/ks4/curated_sa.zarr" + spikes = self._get_si_spike_times(sa_dir)[units] + # get index of spike times in data in SAMPLE_RATE Hz too + spikes_idx = spikes * int(self.spike_meta[0]['imSampRate']) # since each session has one behaviour session, now only one action # label file - actions = action_labels[:, 0] - events = action_labels[:, 1] - # get timestamps index of behaviour in pixels stream (ms) - timestamps = action_labels[:, 2] + actions = action_labels["outcome"] + events = action_labels["events"] + # get timestamps index of behaviour in self.SAMPLE_RATE hz, to convert + # it to ms, do timestamps*1000/self.SAMPLE_RATE + timestamps = action_labels["timestamps"] # select frames of wanted trial type - #starts = np.where(np.bitwise_and(actions, label))[0] trials = np.where(np.bitwise_and(actions, label))[0] # map starts by event starts = np.where(np.bitwise_and(events, event))[0] From 8e1600eaa7e56cec8bc37293403e9c3d95ae02a2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:29:57 +0000 Subject: [PATCH 088/658] index back into `trials` to get the actual starting index --- pixels/behaviours/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 3ac9d8b..d8c5a6a 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1856,14 +1856,14 @@ def _get_aligned_trials( ends = np.where(np.bitwise_and(events, end_event))[0] # only take starts from selected trials - selected_starts = np.where(np.isin(trials, starts))[0] + selected_starts = trials[np.where(np.isin(trials, starts))[0]] start_t = timestamps[selected_starts] # only take ends from selected trials - selected_ends = np.where(np.isin(trials, ends))[0] + selected_ends = trials[np.where(np.isin(trials, ends))[0]] end_t = timestamps[selected_ends] # pad ends with 1 second extra to remove edge effects from convolution - scan_pad = self.sample_rate + scan_pad = self.SAMPLE_RATE scan_starts = start_t - scan_pad scan_ends = end_t + scan_pad scan_durations = scan_ends - scan_starts From b1fa2ba0bf3556c5c2c2b7b31a448c05902871cc Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:30:21 +0000 Subject: [PATCH 089/658] lag is the first in timestamps --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d8c5a6a..1e4ebd4 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1893,7 +1893,7 @@ def _get_aligned_trials( if not self._lag[rec_num] == None: lag_start, _ = self._lag[rec_num] else: - lag_start = action_labels[0, -1] + lag_start = timestamps[0] if lag_start < 0: rec_spikes = rec_spikes + lag_start From e8eb93dc2cc6418f375b679d5e9ccb794766c7de Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:34:59 +0000 Subject: [PATCH 090/658] concat dfs, not rec_trials; use spikeinterface to filter units --- pixels/behaviours/base.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1e4ebd4..c04d754 100644 --- a/pixels/behaviours/base.py 
+++ b/pixels/behaviours/base.py @@ -1992,12 +1992,37 @@ def _get_aligned_trials( # TODO july 10 2024 shuffle spike times for each unit across - trials = pd.concat(rec_trials, axis=1, names=["trial", "unit"]) + trials = pd.concat(dfs, axis=1, names=["trial", "unit"]) trials = trials.reorder_levels(["unit", "trial"], axis=1) trials = trials.sort_index(level=0, axis=1) return trials + def si_select_units(self, sa_dir, min_depth=0, max_depth=None, name=None): + """ + Use spikeinterface sorting analyser to select units. + """ + # load sorting analyser + sa = si.load_sorting_analyzer(sa_dir) + + # get units + unit_ids = sa.unit_ids + + # init units class + selected_units = SelectedUnits() + if name is not None: + selected_units.name = name + + # get coordinates of channel with max. amplitude + max_chan_coords = sa.sorting.get_property("max_chan_coords") + # get depths + depths = max_chan_coords[:, 1] + # select units within depths range + in_range = unit_ids[(depths >= min_depth) & (depths < max_depth)] + selected_units.extend(in_range) + + return selected_units + def select_units( self, group='good', min_depth=0, max_depth=None, min_spike_width=None, max_spike_width=None, uncurated=False, name=None From 18ad46e05249294f84817c2b71c3f0c38dba0665 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 14:35:37 +0000 Subject: [PATCH 091/658] use upper case to represent constant --- pixels/behaviours/base.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c04d754..728ba61 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2123,7 +2123,7 @@ def _get_neuro_raw(self, kind): data_file = self.find_file(recording[f'{kind}_data'], copy=False) orig_rate = int(meta[rec_num]['imSampRate']) num_chans = int(meta[rec_num]['nSavedChans']) - factor = orig_rate / self.sample_rate + factor = orig_rate / self.SAMPLE_RATE data = ioutils.read_bin(data_file, num_chans) @@ -2252,7 +2252,7 @@ def align_trials( getter = getattr(self, f"get_{data}_data_raw", None) if not getter: raise PixelsError(f"align_trials: {data} doesn't have a 'raw' option.") - values, sample_rate = getter() + values, SAMPLE_RATE = getter() else: print(f"Aligning {data} data to trials.") @@ -2262,18 +2262,18 @@ def align_trials( values = self.get_motion_index_data(video_match) else: values = getattr(self, f"get_{data}_data")() - sample_rate = self.sample_rate + SAMPLE_RATE = self.SAMPLE_RATE if not values or values[0] is None: raise PixelsError(f"align_trials: Could not get {data} data.") trials = [] # The logic here is that the action labels will always have a sample rate of - # self.sample_rate, whereas our data here may differ. 'duration' is used to scan + # self.SAMPLE_RATE, whereas our data here may differ. 'duration' is used to scan # the action labels, so always give it 5 seconds to scan, then 'half' is used to # index data. - scan_duration = self.sample_rate * 10 - half = (sample_rate * duration) // 2 + scan_duration = self.SAMPLE_RATE * 10 + half = (SAMPLE_RATE * duration) // 2 if isinstance(half, float): assert half.is_integer() # In case duration is a float < 1 half = int(half) @@ -2304,7 +2304,7 @@ def align_trials( print("No event found for an action. 
If this is OK, ignore this.") continue centre = start + centre[0] - centre = int(centre * sample_rate / self.sample_rate) + centre = int(centre * SAMPLE_RATE / self.SAMPLE_RATE) trial = values[rec_num][centre - half + 1:centre + half + 1] if isinstance(trial, np.ndarray): @@ -2352,8 +2352,8 @@ def align_clips(self, label, event, video_match, duration=1): """ action_labels = self.get_action_labels() - scan_duration = self.sample_rate * 8 - half = int((self.sample_rate * duration) / 2) + scan_duration = self.SAMPLE_RATE * 8 + half = int((self.SAMPLE_RATE * duration) / 2) cursor = 0 # In sample points i = -1 rec_trials = [] @@ -2375,7 +2375,7 @@ def align_clips(self, label, event, video_match, duration=1): behavioural_data = ioutils.read_tdms(self.find_file(recording['behaviour'])) assert 0 behavioural_data = behavioural_data["/'CamFrames'/'0'"] - behav_array = signal.resample(behavioural_data.values, 25000, self.sample_rate) + behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array)] = np.squeeze(behav_array) behavioural_data = behavioural_data[:len(behav_array)] trigger = signal.binarise(behavioural_data).values From a27e0ab613c460b15820e88f1d00105376188712 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Dec 2024 15:39:32 +0000 Subject: [PATCH 092/658] allows dynamic sample rate --- pixels/signal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/signal.py b/pixels/signal.py index faca2d5..991b15e 100644 --- a/pixels/signal.py +++ b/pixels/signal.py @@ -301,7 +301,7 @@ def median_subtraction(data, axis=0): return data - np.median(data, axis=axis, keepdims=True) -def convolve_spike_trains(times, sigma=100, size=10): +def convolve_spike_trains(times, sigma=100, size=10, sample_rate=1000): """ Convolve spike times data with 1D gaussian kernel to get spike rate. 
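# --- illustrative sketch, not part of this patch ----------------------------
# A minimal usage example for convolve_spike_trains with the new sample_rate
# argument. The binary spike matrix below is synthetic; the unit ids and the
# 2 kHz rate are only illustrative (SAMPLE_RATE is 2000 in base.py). Note the
# input must be float, as the docstring warns.
import numpy as np
import pandas as pd
from pixels import signal

# ~5 s of fake 0/1 spike trains for 3 units at 2 kHz
spiked = pd.DataFrame(
    (np.random.rand(10000, 3) < 0.005).astype(float),
    columns=[101, 102, 103],
)
rates = signal.convolve_spike_trains(spiked, sigma=100, sample_rate=2000)
# `rates` is time x units in spikes per second, smoothed with a gaussian kernel
# -----------------------------------------------------------------------------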
@@ -334,7 +334,7 @@ def convolve_spike_trains(times, sigma=100, size=10): output=float, mode='nearest', axis=0, - ) * 1000 # rescale it to second + ) * sample_rate # rescale it to second df = pd.DataFrame(convolved, columns=times.columns) From a777f31afed66cc7c3f1fb6e98a7f513cf55a7a6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:14:23 +0000 Subject: [PATCH 093/658] remove all nan rows in spike times --- pixels/behaviours/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 728ba61..43500eb 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1837,8 +1837,8 @@ def _get_aligned_trials( #spikes = self._get_spike_times()[units] sa_dir="/home/amz/synology/arthur/data/npx/interim/20240812_az_VDCN09/ks4/curated_sa.zarr" spikes = self._get_si_spike_times(sa_dir)[units] - # get index of spike times in data in SAMPLE_RATE Hz too - spikes_idx = spikes * int(self.spike_meta[0]['imSampRate']) + # drop rows if all nans + spikes = spikes.dropna(how="all") # since each session has one behaviour session, now only one action # label file From 32e80497bcf690ca58a68c176d5269de3944904f Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:14:44 +0000 Subject: [PATCH 094/658] get original trial ids from raw --- pixels/behaviours/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 43500eb..583ad68 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1862,6 +1862,9 @@ def _get_aligned_trials( selected_ends = trials[np.where(np.isin(trials, ends))[0]] end_t = timestamps[selected_ends] + # use original trial id as trial index + trial_ids = vr_data.iloc[selected_starts].trial_count.unique() + # pad ends with 1 second extra to remove edge effects from convolution scan_pad = self.SAMPLE_RATE scan_starts = start_t - scan_pad From df4ef3dc018738d60c00b902f3165cf2ecffd477 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:15:22 +0000 Subject: [PATCH 095/658] allows more flexibile sample rate scale conversion --- pixels/behaviours/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 583ad68..ca5af69 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1884,10 +1884,13 @@ def _get_aligned_trials( meta = self.spike_meta[rec_num] samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 assert samples.is_integer() - milliseconds = samples / 30 - cursor_duration = cursor / 30 + in_SAMPLE_RATE_scale = (samples * self.SAMPLE_RATE)\ + / int(self.spike_meta[0]['imSampRate']) + cursor_duration = (cursor * self.SAMPLE_RATE)\ + / int(self.spike_meta[0]['imSampRate']) rec_spikes = spikes[ - (cursor_duration <= spikes) & (spikes < (cursor_duration + milliseconds)) + (cursor_duration <= spikes)\ + & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) ] - cursor_duration cursor += samples From 7726d3945279e7c46bf79722d0fed1255f567cb2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:15:57 +0000 Subject: [PATCH 096/658] do not include the last position of trial to keep the same shape as spikes --- pixels/behaviours/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index ca5af69..d8d2676 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1912,8 
+1912,8 @@ def _get_aligned_trials( # get position bin ids for current trial trial_pos_bool = (positions.index >= start_t[i])\ - & (positions.index <= end_t[i]) - trial_bin_pos = positions[trial_pos_bool] + & (positions.index < end_t[i]) + trial_pos = positions[trial_pos_bool] # initiate binary spike times array for current trial # NOTE: dtype must be float otherwise would get all 0 when From 69a8563f288f4f5bd7ce8d5134fab7c387668506 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:24:29 +0000 Subject: [PATCH 097/658] takes in sample rate while convolving spike rate; reset rates index before putting into dic; get isi to convert index into ms --- pixels/behaviours/base.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d8d2676..f2adf0e 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1946,23 +1946,31 @@ def _get_aligned_trials( rates = signal.convolve_spike_trains( times=spiked, sigma=sigma, + sample_rate=self.SAMPLE_RATE, ) # keep the original index rates.index = spiked.index # remove 1s padding from the start and end rates = rates.iloc[scan_pad: -scan_pad] # reset index to zero at the beginning of the trial - #rates.reset_index(inplace=True, drop=True) - rec_trials[i] = rates + rates.reset_index(inplace=True, drop=True) + # add position here to bin together + rates['positions'] = trial_pos.values + + rec_trials[trial_ids[i]] = rates bin_trial = rates.copy() - # add position here to bin together - bin_trial['positions'] = positions + # get inter-sample-interval, time interval between each sample + # in milliseconds + isi = (1 / self.SAMPLE_RATE) * 1000 # reset index to zero at the beginning of the trial bin_trial.reset_index(inplace=True, drop=True) # convert index to datetime index for resampling - bin_trial.index = pd.to_timedelta(bin_trial.index, unit='ms') + bin_trial.index = pd.to_timedelta( + arg=bin_trial.index * isi, + unit="ms", + ) # resample to 100ms bin bin_trial = bin_trial.resample(time_bin).mean() # use numeric index From 83cea8fa0873ac837fdce230ba39a502dc62d1b6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:27:50 +0000 Subject: [PATCH 098/658] save binned positions in a separate column --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f2adf0e..f4cc066 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1978,7 +1978,7 @@ def _get_aligned_trials( # bin positions and only save the bin index # NOTE: here position bin index starts at 1, for alfredo # to make it back to 0-indexing, remove +1 at the end - bin_trial['positions'] = bin_trial['positions'] // pos_bin + 1 + bin_trial['pos_bin'] = bin_trial['positions'] // pos_bin + 1 bin_trials[i] = bin_trial From 029d5714bcfd624ffbc81197b7755bbaf915761e Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:28:40 +0000 Subject: [PATCH 099/658] save firing rate & binned position in two arrays; concat rec_trials too --- pixels/behaviours/base.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f4cc066..8df666f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1998,17 +1998,34 @@ def _get_aligned_trials( ).T # reshape into trials x units x bins # save output, for alfredo - np.save(output_path, output) + # use 
label as array key name + arr_to_save = { + "fr": output[:, :-2, :], + "pos": output[:, -2:, :], + } + np.savez_compressed(output_path, **arr_to_save) print(f"> Output saved at {output_path}.") if not rec_trials: return None - # TODO july 10 2024 shuffle spike times for each unit across + # align all trials by index + rec_indices = list(set().union( + *[df.index for df in rec_trials.values()]) + ) + # reindex all trials by the longest trial + raw_dfs = {key: df.reindex(index=rec_indices) + for key, df in rec_trials.items()} - trials = pd.concat(dfs, axis=1, names=["trial", "unit"]) + # TODO dec 13 2024: mixing position in unit level cause performance + # warning during save, fix it + + # TODO july 10 2024 shuffle spike times for each unit across the whole + # recording + + trials = pd.concat(raw_dfs, axis=1, names=["trial", "unit"]) trials = trials.reorder_levels(["unit", "trial"], axis=1) - trials = trials.sort_index(level=0, axis=1) + trials.sort_index(level=0, axis=1) return trials From 2110189801e06ecf2021200b899853c71ef972e3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 11:29:21 +0000 Subject: [PATCH 100/658] remove space --- pixels/behaviours/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 8df666f..ee760f7 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1986,7 +1986,6 @@ def _get_aligned_trials( all_indices = list(set().union( *[df.index for df in bin_trials.values()]) ) - # reindex all trials by the longest trial dfs = {key: df.reindex(index=all_indices) for key, df in bin_trials.items()} From 99c82888a76608979fe36bcb8ea322032fce5480 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 13:49:02 +0000 Subject: [PATCH 101/658] allows more flexible key and mode --- pixels/ioutils.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 732af8b..0e9dc6c 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -301,7 +301,7 @@ def read_hdf5(path): return df -def write_hdf5(path, df): +def write_hdf5(path, df, key="df", mode="w"): """ Write a dataframe to an h5 file. @@ -313,12 +313,23 @@ def write_hdf5(path, df): df : pd.DataFrame Dataframe to save to h5. + key : str + Identifier for the group in the store. + Default: "df". + + mode : str + Mode used to open the file. + Default: "w" (write). + Options: + "a": append; if the file does not exist, it is created. + "r+": similar to "a", but the file must already exist.
""" df.to_hdf( path_or_buf=path, - key='df', - mode='w', + key=key, + mode=mode, complevel=9, + #complib="bzip2", # slower but higher compression ratio complib="blosc:lz4hc", ) From 2d8d5721aac8a55da6e9572931d159aac84b6027 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 13:50:05 +0000 Subject: [PATCH 102/658] allows to read and write dict contains multiple dfs into one hdf5 file --- pixels/behaviours/base.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index ee760f7..fbaa358 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -92,13 +92,38 @@ def func(*args, **kwargs): df = ioutils.read_hdf5(output) except HDF5ExtError: df = None + except (KeyError, ValueError): + # if key="df" is not found, then use HDFStore to list and read + # all dfs + with pd.HDFStore(output, "r") as store: + # list all keys + keys = store.keys() + # create df as a dictionary to hold all dfs + df = {} + for key in keys: + # read current df + data = store[key] + # remove "/" in key + key_name = key.lstrip("/") + # use key name as dict key + df[key_name] = data else: df = method(*args, **kwargs) output.parent.mkdir(parents=True, exist_ok=True) if df is None: output.touch() else: - ioutils.write_hdf5(output, df) + # allows to save multiple dfs in a dict in one hdf5 file + if isinstance(df, dict): + for name, df in df.items(): + ioutils.write_hdf5( + path=output, + df=df, + key=name, + mode="a", + ) + else: + ioutils.write_hdf5(output, df) return df return func From fde70cf2ab8d559310869e7048227043832528e1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 13:53:44 +0000 Subject: [PATCH 103/658] saves positions in a separate df --- pixels/behaviours/base.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index fbaa358..87c2d1d 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1898,6 +1898,7 @@ def _get_aligned_trials( cursor = 0 # In sample points rec_trials = {} + trial_positions = {} bin_trials = {} for rec_num in range(len(self.files)): @@ -1979,12 +1980,14 @@ def _get_aligned_trials( rates = rates.iloc[scan_pad: -scan_pad] # reset index to zero at the beginning of the trial rates.reset_index(inplace=True, drop=True) - # add position here to bin together - rates['positions'] = trial_pos.values + trial_pos.reset_index(inplace=True, drop=True) rec_trials[trial_ids[i]] = rates + trial_positions[trial_ids[i]] = trial_pos bin_trial = rates.copy() + # add position here to bin together + bin_trial['positions'] = trial_pos.values # get inter-sample-interval, time interval between each sample # in milliseconds @@ -2040,9 +2043,10 @@ def _get_aligned_trials( # reindex all trials by the longest trial raw_dfs = {key: df.reindex(index=rec_indices) for key, df in rec_trials.items()} + raw_pos = {key: df.reindex(index=rec_indices) + for key, df in trial_positions.items()} - # TODO dec 13 2024: mixing position in unit level cause performance - # warning during save, fix it + positions = pd.concat(raw_pos, axis=1, names="trial") # TODO july 10 2024 shuffle spike times for each unit across the whole # recording @@ -2051,7 +2055,7 @@ def _get_aligned_trials( trials = trials.reorder_levels(["unit", "trial"], axis=1) trials.sort_index(level=0, axis=1) - return trials + return {"trials": trials, "positions": positions} def si_select_units(self, sa_dir, min_depth=0, 
max_depth=None, name=None): """ From 6d93499c5af6a8fd467bf3611dbba67fd4e7fe6b Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 13:58:43 +0000 Subject: [PATCH 104/658] use more meaningful names --- pixels/behaviours/base.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 87c2d1d..378f8a6 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1844,8 +1844,8 @@ def _get_aligned_trials( """ action_labels = self.get_action_labels()[0] - # define output path - output_path = self.interim/\ + # define output path for alfredo + output_al_path = self.interim/\ f'cache/{self.name}_{label}_{units}_fr_for_AL.npz' if units is None: @@ -2014,24 +2014,24 @@ def _get_aligned_trials( all_indices = list(set().union( *[df.index for df in bin_trials.values()]) ) - # reindex all trials by the longest trial - dfs = {key: df.reindex(index=all_indices) + # reindex bin_trial by the longest trial + bin_dfs = {key: df.reindex(index=all_indices) for key, df in bin_trials.items()} - # get output - output = np.stack( - [df.values for df in dfs.values()], + # stack df values into np array + bin_arr = np.stack( + [df.values for df in bin_dfs.values()], axis=2, ).T # reshape into trials x units x bins - # save output, for alfredo + # save bin_arr, for alfredo # use label as array key name arr_to_save = { - "fr": output[:, :-2, :], - "pos": output[:, -2:, :], + "fr": bin_arr[:, :-2, :], + "pos": bin_arr[:, -2:, :], } - np.savez_compressed(output_path, **arr_to_save) - print(f"> Output saved at {output_path}.") + np.savez_compressed(output_al_path, **arr_to_save) + print(f"> Output saved at {output_al_path}.") if not rec_trials: return None From 411b81c79dea67c1577b4535d750ec49a4b4997d Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 18:19:42 +0000 Subject: [PATCH 105/658] add sorting analyser directory --- pixels/ioutils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 0e9dc6c..3281c1c 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -122,6 +122,7 @@ def get_data_files(data_dir, session_name): recording['vr'] = recording['spike_data'].with_name( f'{session_name}_vr_synched.h5' ) + recording['sorting_analyser'] = Path(f'ks4/curated_sa.zarr') files.append(recording) From 1699c3796d582bf3f4cb974ffd7d41c104ffcd56 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 18:19:59 +0000 Subject: [PATCH 106/658] for trial rate, concat firing rate & positions separately --- pixels/experiment.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pixels/experiment.py b/pixels/experiment.py index 8466449..aa7f297 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -237,6 +237,31 @@ def align_trials(self, *args, units=None, **kwargs): names=["session", "trial", "scorer", "bodyparts", "coords"] ) + if "trial_rate" in kwargs.values(): + frs = {} + positions = {} + for s in trials: + frs[s] = trials[s]["fr"] + positions[s] = trials[s]["positions"] + + frs_df = pd.concat( + frs.values(), + axis=1, + copy=False, + keys=frs.keys(), + names=["session"] + ) + pos_df = pd.concat( + positions.values(), + axis=1, + copy=False, + keys=positions.keys(), + names=["session"] + ) + df = { + "fr": frs_df, + "positions": pos_df, + } else: df = pd.concat( trials.values(), axis=1, copy=False, From 9dd3917132c41824d1e002fb20a5eba1a3550550 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 
18:20:39 +0000 Subject: [PATCH 107/658] avoid using `df` in the loop to overwrite output variable --- pixels/behaviours/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 378f8a6..a3109df 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -115,10 +115,10 @@ def func(*args, **kwargs): else: # allows to save multiple dfs in a dict in one hdf5 file if isinstance(df, dict): - for name, df in df.items(): + for name, values in df.items(): ioutils.write_hdf5( path=output, - df=df, + df=values, key=name, mode="a", ) From 52cbb3e93bc8a2397f8076c87d01cd3ed98a38a2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 18:23:42 +0000 Subject: [PATCH 108/658] give output more meaningful name --- pixels/behaviours/base.py | 35 +++++------------------------------ 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a3109df..e76b6df 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2051,40 +2051,15 @@ def _get_aligned_trials( # TODO july 10 2024 shuffle spike times for each unit across the whole # recording - trials = pd.concat(raw_dfs, axis=1, names=["trial", "unit"]) - trials = trials.reorder_levels(["unit", "trial"], axis=1) - trials.sort_index(level=0, axis=1) - - return {"trials": trials, "positions": positions} - - def si_select_units(self, sa_dir, min_depth=0, max_depth=None, name=None): - """ - Use spikeinterface sorting analyser to select units. - """ - # load sorting analyser - sa = si.load_sorting_analyzer(sa_dir) - - # get units - unit_ids = sa.unit_ids + fr = pd.concat(raw_dfs, axis=1, names=["trial", "unit"]) + fr = fr.reorder_levels(["unit", "trial"], axis=1) + fr.sort_index(level=0, axis=1) - # init units class - selected_units = SelectedUnits() - if name is not None: - selected_units.name = name - - # get coordinates of channel with max. amplitude - max_chan_coords = sa.sorting.get_property("max_chan_coords") - # get depths - depths = max_chan_coords[:, 1] - # select units within depths range - in_range = unit_ids[(depths >= min_depth) & (depths < max_depth)] - selected_units.extend(in_range) - - return selected_units + return {"fr": fr, "positions": positions} def select_units( self, group='good', min_depth=0, max_depth=None, min_spike_width=None, - max_spike_width=None, uncurated=False, name=None + max_spike_width=None, uncurated=False, name=None, use_si=False, ): """ Select units based on specified criteria. The output of this can be passed to From 2e77cd7fe42a375409fb8107de34e2e1fea812ab Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 16 Dec 2024 18:24:17 +0000 Subject: [PATCH 109/658] put use spikeinterface to get unit as a method in `select_units` --- pixels/behaviours/base.py | 112 +++++++++++++++++++++++--------------- 1 file changed, 68 insertions(+), 44 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e76b6df..c85a77b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2098,57 +2098,81 @@ def select_units( is the same between uses of the same name. 
""" - cluster_info = self.get_cluster_info() - selected_units = SelectedUnits() - if name is not None: - selected_units.name = name - - if min_depth is not None or max_depth is not None: - probe_depths = self.get_probe_depth() - - if min_spike_width == 0: - min_spike_width = None - if min_spike_width is not None or max_spike_width is not None: - widths = self.get_spike_widths() + if use_si: + self.sa_dir = self.find_file(self.files[0]["sorting_analyser"]) + # load sorting analyser + sa = si.load_sorting_analyzer(self.sa_dir) + + # get units + unit_ids = sa.unit_ids + + # init units class + selected_units = SelectedUnits() + if name is not None: + selected_units.name = name + + # get coordinates of channel with max. amplitude + max_chan_coords = sa.sorting.get_property("max_chan_coords") + # get depths + depths = max_chan_coords[:, 1] + # select units within depths range + in_range = unit_ids[(depths >= min_depth) & (depths < max_depth)] + selected_units.extend(in_range) + + return selected_units + else: - widths = None + cluster_info = self.get_cluster_info() + selected_units = SelectedUnits() + if name is not None: + selected_units.name = name + + if min_depth is not None or max_depth is not None: + probe_depths = self.get_probe_depth() + + if min_spike_width == 0: + min_spike_width = None + if min_spike_width is not None or max_spike_width is not None: + widths = self.get_spike_widths() + else: + widths = None - for stream_num, info in enumerate(cluster_info): - # TODO jun 12 2024 skip stream 1 for now - if stream_num > 0: - continue + for stream_num, info in enumerate(cluster_info): + # TODO jun 12 2024 skip stream 1 for now + if stream_num > 0: + continue + + id_key = 'id' if 'id' in info else 'cluster_id' + grouping = 'KSLabel' if uncurated else 'group' - id_key = 'id' if 'id' in info else 'cluster_id' - grouping = 'KSLabel' if uncurated else 'group' - - for unit in info[id_key]: - unit_info = info.loc[info[id_key] == unit].iloc[0].to_dict() - - # we only want units that are in the specified group - if not group or unit_info[grouping] == group: - - # and that are within the specified depth range - if min_depth is not None: - if probe_depths[stream_num] - unit_info['depth'] <= min_depth: - continue - if max_depth is not None: - if probe_depths[stream_num] - unit_info['depth'] > max_depth: - continue - - # and that have the specified median spike widths - if widths is not None: - width = widths[widths['unit'] == unit]['median_ms'] - assert len(width.values) == 1 - if min_spike_width is not None: - if width.values[0] < min_spike_width: + for unit in info[id_key]: + unit_info = info.loc[info[id_key] == unit].iloc[0].to_dict() + + # we only want units that are in the specified group + if not group or unit_info[grouping] == group: + + # and that are within the specified depth range + if min_depth is not None: + if probe_depths[stream_num] - unit_info['depth'] <= min_depth: continue - if max_spike_width is not None: - if width.values[0] > max_spike_width: + if max_depth is not None: + if probe_depths[stream_num] - unit_info['depth'] > max_depth: continue - selected_units.append(unit) + # and that have the specified median spike widths + if widths is not None: + width = widths[widths['unit'] == unit]['median_ms'] + assert len(width.values) == 1 + if min_spike_width is not None: + if width.values[0] < min_spike_width: + continue + if max_spike_width is not None: + if width.values[0] > max_spike_width: + continue + + selected_units.append(unit) - return selected_units + return selected_units 
def _get_neuro_raw(self, kind): raw = [] From 8eb460f636d053240aa64a12160da79a71f23beb Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:24:45 +0000 Subject: [PATCH 110/658] re-index dict of dfs by the longest trial, then stack or concat --- pixels/ioutils.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 3281c1c..c44e279 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -670,3 +670,41 @@ def stream_video(video, length=None): length -= 1 if length == 0: break + +def reindex_by_longest(dfs, return_format="array", names=None): + """ + params + === + dfs: dict, dictionary with pandas dataframes as values. + + return_format: str, format of return value. + "array": stacked np array. + "dataframe": concatenated pandas dataframe. + + names: str or list of str, names for levels of concatenated dataframe. + + return + === + np.array or pd.DataFrame. + """ + # align all trials by index + indices = list(set().union( + *[df.index for df in dfs.values()]) + ) + + # reindex by the longest + reidx_dfs = {key: df.reindex(index=indices) + for key, df in dfs.items()} + + if return_format == "array": + # stack df values into np array + output = np.stack( + [df.values for df in reidx_dfs.values()], + axis=-1, + ) + elif return_format == "dataframe": + # concatenate dfs + output = pd.concat(reidx_dfs, axis=1, names=names) + + return output + From 356edd7020169897acf4818e6fe10a5da2137841 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:27:31 +0000 Subject: [PATCH 111/658] bin data by given temporal and positional bin --- pixels/behaviours/base.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c85a77b..fe80d9f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2951,3 +2951,42 @@ def get_aligned_spike_rate_CI( df = pd.concat(cis, axis=1, names=['unit', 'bin']) df.set_index(pd.Index(percentiles, name="percentile"), inplace=True) return df + + + def bin_vr_trial(self, data, positions, time_bin, pos_bin, bin_method="mean"): + """ + Bin virtual reality trials by given temporal bin and positional bin.
+ """ + data = data.copy() + positions = positions.copy() + + # convert index to datetime index for resampling + isi = (1 / self.SAMPLE_RATE) * 1000 + data.index = pd.to_timedelta( + arg=data.index * isi, + unit="ms", + ) + + # set position index too + positions.index = data.index + + # resample to 100ms bin, and get position mean + mean_pos = positions.resample(time_bin).mean() + + if bin_method == "sum": + # resample to Xms bin, and get sum + bin_data = data.resample(time_bin).sum() + elif bin_method == "mean": + # resample to Xms bin, and get mean + bin_data = data.resample(time_bin).mean() + + # add position here to bin together + bin_data['positions'] = mean_pos.values + # add bin positions + bin_pos = mean_pos // pos_bin + 1 + bin_data['bin_pos'] = bin_pos.values + + # use numeric index + bin_data.reset_index(inplace=True, drop=True) + + return bin_data From 88db28502ebebc9ea1334a4dbb2db4c968793486 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:31:06 +0000 Subject: [PATCH 112/658] add doc --- pixels/behaviours/base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index fe80d9f..4e05d2a 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2956,6 +2956,20 @@ def get_aligned_spike_rate_CI( def bin_vr_trial(self, data, positions, time_bin, pos_bin, bin_method="mean"): """ Bin virtual reality trials by given temporal bin and positional bin. + + params + === + data: pandas dataframe, neural data needed binning. + + positions: pandas dataframe, position of current trial. + + time_bin: str, temporal bin for neural data. + + pos_bin: int, positional bin for positions. + + bin_method: str, method to concatenate data within each temporal bin. + "mean": taking the mean of all frames. + "sum": taking sum of all frames. 
""" data = data.copy() positions = positions.copy() From a3f0abac2d2295b3985632137306f6762952968f Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:31:25 +0000 Subject: [PATCH 113/658] get directory of sorting analyser --- pixels/behaviours/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4e05d2a..a1018f3 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1650,15 +1650,17 @@ def get_lfp_data(self): return self._get_processed_data("_lfp_data", "lfp_processed") - def _get_si_spike_times(self, sa_dir): + def _get_si_spike_times(self): """ get spike times in second with spikeinterface """ + self.sa_dir = self.find_file(self.files[0]["sorting_analyser"]) + spike_times = self._spike_times_data for stream_num, stream in enumerate(range(len(spike_times))): # load sorting analyser - sa = si.load_sorting_analyzer(sa_dir) + sa = si.load_sorting_analyzer(self.sa_dir) times = {} # get spike train From 7a7fcf62769d11f6253a354b3bb44ff7bcef5d3f Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:32:10 +0000 Subject: [PATCH 114/658] put getting spike times using spikeinterface as an option in get_spike_times --- pixels/behaviours/base.py | 65 ++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a1018f3..2536a33 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1682,7 +1682,7 @@ def _get_si_spike_times(self): return spike_times[0] # NOTE: only deal with one stream for now - def _get_spike_times(self, remapped=False): + def _get_spike_times(self, remapped=False, use_si=False): """ Returns the sorted spike times. @@ -1694,41 +1694,44 @@ def _get_spike_times(self, remapped=False): spike_times = self._spike_times_data for stream_num, stream in enumerate(range(len(spike_times))): - if remapped and stream_num > 0: - times = self.ks_outputs[stream_num] / f'spike_times_remapped.npy' - print(f"""\n> Found remapped spike times from\r - {self.ks_outputs[stream_num]}, try to load this.""") + if use_si: + spike_times[stream_num] = self._get_si_spike_times() else: - times = self.ks_outputs[stream_num] / f'spike_times.npy' + if remapped and stream_num > 0: + times = self.ks_outputs[stream_num] / f'spike_times_remapped.npy' + print(f"""\n> Found remapped spike times from\r + {self.ks_outputs[stream_num]}, try to load this.""") + else: + times = self.ks_outputs[stream_num] / f'spike_times.npy' - clust = self.ks_outputs[stream_num] / f'spike_clusters.npy' + clust = self.ks_outputs[stream_num] / f'spike_clusters.npy' - try: - times = np.load(times) - clust = np.load(clust) - except FileNotFoundError: - msg = ": Can't load spike times that haven't been extracted!" - raise PixelsError(self.name + msg) - - times = np.squeeze(times) - clust = np.squeeze(clust) - by_clust = {} + try: + times = np.load(times) + clust = np.load(clust) + except FileNotFoundError: + msg = ": Can't load spike times that haven't been extracted!" 
+ raise PixelsError(self.name + msg) - for c in np.unique(clust): - c_times = times[clust == c] - uniques, counts = np.unique( - c_times, - return_counts=True, - ) - repeats = c_times[np.where(counts>1)] - if len(repeats>1): - print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") + times = np.squeeze(times) + clust = np.squeeze(clust) + by_clust = {} - by_clust[c] = pd.Series(uniques) - spike_times[stream_num] = pd.concat(by_clust, axis=1, names=['unit']) - # Convert to time into sample rate index - spike_times[stream_num] /= int(self.spike_meta[0]['imSampRate'])\ - / self.SAMPLE_RATE + for c in np.unique(clust): + c_times = times[clust == c] + uniques, counts = np.unique( + c_times, + return_counts=True, + ) + repeats = c_times[np.where(counts>1)] + if len(repeats>1): + print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") + + by_clust[c] = pd.Series(uniques) + spike_times[stream_num] = pd.concat(by_clust, axis=1, names=['unit']) + # Convert to time into sample rate index + spike_times[stream_num] /= int(self.spike_meta[0]['imSampRate'])\ + / self.SAMPLE_RATE return spike_times[0] From 9f08f3902223075a51efd6d2523b78bc89967aa6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:33:17 +0000 Subject: [PATCH 115/658] explicitly put output content in names --- pixels/behaviours/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2536a33..1d34239 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1849,9 +1849,11 @@ def _get_aligned_trials( """ action_labels = self.get_action_labels()[0] - # define output path for alfredo - output_al_path = self.interim/\ - f'cache/{self.name}_{label}_{units}_fr_for_AL.npz' + # define output path for binned spike rate + output_fr_path = self.interim/\ + f'cache/{self.name}_{label}_{units}_{time_bin}_spike_rate.npz' + output_count_path = self.interim/\ + f'cache/{self.name}_{label}_{units}_{time_bin}_spike_count.npz' if units is None: units = self.select_units() From 66a5c82b7e52afa04ed3918204cdcd358e0cdb98 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:33:55 +0000 Subject: [PATCH 116/658] use spikeinterface to get spike times --- pixels/behaviours/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1d34239..c60c696 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1866,9 +1866,7 @@ def _get_aligned_trials( #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! 
- #spikes = self._get_spike_times()[units] - sa_dir="/home/amz/synology/arthur/data/npx/interim/20240812_az_VDCN09/ks4/curated_sa.zarr" - spikes = self._get_si_spike_times(sa_dir)[units] + spikes = self._get_spike_times(use_si=True)[units] # drop rows if all nans spikes = spikes.dropna(how="all") From 4ec5ca86466cd765f85fc31d16efd1eb5464724d Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:34:17 +0000 Subject: [PATCH 117/658] use more informative names --- pixels/behaviours/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c60c696..3eca3c3 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1904,7 +1904,8 @@ def _get_aligned_trials( cursor = 0 # In sample points rec_trials = {} trial_positions = {} - bin_trials = {} + bin_frs = {} + bin_counts = {} for rec_num in range(len(self.files)): # TODO jun 12 2024 skip other streams for now From 414299904861afd8d2255f31fc09846f3b2fe5fd Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:35:57 +0000 Subject: [PATCH 118/658] make `get_spike_times` public --- pixels/behaviours/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 3eca3c3..bf9c5e4 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -449,8 +449,8 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): if output.exists(): print(f'\n> Spike times from {self.ks_outputs[remap_stream_idx]}\ already remapped, next session.') - cluster_times = self._get_spike_times()[remap_stream_idx] - remapped_cluster_times = self._get_spike_times( + cluster_times = self.get_spike_times()[remap_stream_idx] + remapped_cluster_times = self.get_spike_times( remapped=True)[remap_stream_idx] # get first spike time from each cluster, and their difference @@ -544,8 +544,8 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): print(f'\n> Spike times remapping output saved to\n {output}.') # load remapped spike times of each cluster - cluster_times = self._get_spike_times()[remap_stream_idx] - remapped_cluster_times = self._get_spike_times( + cluster_times = self.get_spike_times()[remap_stream_idx] + remapped_cluster_times = self.get_spike_times( remapped=True)[remap_stream_idx] # get first spike time from each cluster, and their difference @@ -1682,7 +1682,7 @@ def _get_si_spike_times(self): return spike_times[0] # NOTE: only deal with one stream for now - def _get_spike_times(self, remapped=False, use_si=False): + def get_spike_times(self, remapped=False, use_si=False): """ Returns the sorted spike times. @@ -1751,7 +1751,7 @@ def _get_aligned_spike_times( #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! #TODO: spike times cannot be indexed by unit ids anymore - spikes = self._get_spike_times()[units] + spikes = self.get_spike_times()[units] if rate: # pad ends with 1 second extra to remove edge effects from convolution @@ -1866,7 +1866,7 @@ def _get_aligned_trials( #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! 
- spikes = self._get_spike_times(use_si=True)[units] + spikes = self.get_spike_times(use_si=True)[units] # drop rows if all nans spikes = spikes.dropna(how="all") From 8b167ce10214a6f85f28855d05c415ae63918356 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 17 Dec 2024 17:38:51 +0000 Subject: [PATCH 119/658] also export sum of spike count in each temporal bin --- pixels/behaviours/base.py | 103 ++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 59 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index bf9c5e4..761559c 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1980,89 +1980,74 @@ def _get_aligned_trials( sigma=sigma, sample_rate=self.SAMPLE_RATE, ) - # keep the original index - rates.index = spiked.index # remove 1s padding from the start and end rates = rates.iloc[scan_pad: -scan_pad] + spiked = spiked.iloc[scan_pad: -scan_pad] # reset index to zero at the beginning of the trial rates.reset_index(inplace=True, drop=True) + spiked.reset_index(inplace=True, drop=True) trial_pos.reset_index(inplace=True, drop=True) rec_trials[trial_ids[i]] = rates trial_positions[trial_ids[i]] = trial_pos - bin_trial = rates.copy() - # add position here to bin together - bin_trial['positions'] = trial_pos.values - - # get inter-sample-interval, time interval between each sample - # in milliseconds - isi = (1 / self.SAMPLE_RATE) * 1000 - # reset index to zero at the beginning of the trial - bin_trial.reset_index(inplace=True, drop=True) - # convert index to datetime index for resampling - bin_trial.index = pd.to_timedelta( - arg=bin_trial.index * isi, - unit="ms", + # get bin firing rates + bin_frs[i] = self.bin_vr_trial( + data=rates, + positions=trial_pos, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="mean", + ) + # get bin spike count + bin_counts[i] = self.bin_vr_trial( + data=spiked, + positions=trial_pos, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="sum", ) - # resample to 100ms bin - bin_trial = bin_trial.resample(time_bin).mean() - # use numeric index - bin_trial.index = np.arange(0, len(bin_trial)) - # bin positions and only save the bin index - # NOTE: here position bin index starts at 1, for alfredo - # to make it back to 0-indexing, remove +1 at the end - bin_trial['pos_bin'] = bin_trial['positions'] // pos_bin + 1 - - bin_trials[i] = bin_trial - - # align all trials by index - all_indices = list(set().union( - *[df.index for df in bin_trials.values()]) - ) - # reindex bin_trial by the longest trial - bin_dfs = {key: df.reindex(index=all_indices) - for key, df in bin_trials.items()} # stack df values into np array - bin_arr = np.stack( - [df.values for df in bin_dfs.values()], - axis=2, - ).T # reshape into trials x units x bins + # reshape into trials x units x bins + bin_count_arr = ioutils.reindex_by_longest(bin_counts).T + bin_fr_arr = ioutils.reindex_by_longest(bin_frs).T - # save bin_arr, for alfredo + # save bin_fr and bin_count, for alfredo & andrew # use label as array key name - arr_to_save = { - "fr": bin_arr[:, :-2, :], - "pos": bin_arr[:, -2:, :], + fr_to_save = { + "fr": bin_fr_arr[:, :-2, :], + "pos": bin_fr_arr[:, -2:, :], + } + np.savez_compressed(output_fr_path, **fr_to_save) + print(f"> Output saved at {output_fr_path}.") + count_to_save = { + "count": bin_count_arr[:, :-2, :], + "pos": bin_count_arr[:, -2:, :], } - np.savez_compressed(output_al_path, **arr_to_save) - print(f"> Output saved at {output_al_path}.") + np.savez_compressed(output_count_path, **count_to_save) + 
print(f"> Output saved at {output_count_path}.") if not rec_trials: return None - # align all trials by index - rec_indices = list(set().union( - *[df.index for df in rec_trials.values()]) + # concat trial df + positions = ioutils.reindex_by_longest( + dfs=trial_positions, + return_format="dataframe", + names="trial", + ) + fr = ioutils.reindex_by_longest( + dfs=rec_trials, + return_format="dataframe", + names=["trial", "unit"], ) - # reindex all trials by the longest trial - raw_dfs = {key: df.reindex(index=rec_indices) - for key, df in rec_trials.items()} - raw_pos = {key: df.reindex(index=rec_indices) - for key, df in trial_positions.items()} - - positions = pd.concat(raw_pos, axis=1, names="trial") - - # TODO july 10 2024 shuffle spike times for each unit across the whole - # recording - - fr = pd.concat(raw_dfs, axis=1, names=["trial", "unit"]) fr = fr.reorder_levels(["unit", "trial"], axis=1) - fr.sort_index(level=0, axis=1) + fr = fr.sort_index(level=0, axis=1) return {"fr": fr, "positions": positions} + def select_units( self, group='good', min_depth=0, max_depth=None, min_spike_width=None, max_spike_width=None, uncurated=False, name=None, use_si=False, From b612ba22e68031a3345d9c76d1e136790a6ff1f1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 18 Dec 2024 19:50:48 +0000 Subject: [PATCH 120/658] add func to get postional firing rate for vr --- pixels/behaviours/base.py | 96 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 761559c..1d0a2ba 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2995,3 +2995,99 @@ def bin_vr_trial(self, data, positions, time_bin, pos_bin, bin_method="mean"): bin_data.reset_index(inplace=True, drop=True) return bin_data + + + @_cacheable + def get_positional_rate( + self, self, label, event, end_event=None, sigma=None, time_bin=None, + pos_bin=None, units=None, + ): + """ + Get positional firing rate of selected units in vr, and spatial + occupancy of each position. 
+ """ + # TODO dec 18 2024: + # rearrange vr specific funcs to vr module + # put pixels specific funcs in pixels module + + TUNNEL_RESET = 600 # cm + + # get aligned firing rates and positions + trials = self.align_trials( + label=label, + event=event, + data="trial_rate", + end_event=end_event, + sigma=sigma, + time_bin=time_bin, + pos_bin=pos_bin, + units=units, + ) + fr = trials["fr"] + positions = trials["positions"] + + # get unit_ids + unit_ids = fr.columns.get_level_values("unit").unique() + + # create position indices + indices = np.arange(0, TUNNEL_RESET+2) + # create occupancy array for trials + occupancy = np.full( + (TUNNEL_RESET+2, positions.shape[1]), + np.nan, + ) + # create array for positional firing rate + pos_fr = {} + + for t, trial in enumerate(positions): + # get trial position + trial_pos = positions[trial].dropna() + + # floor pre reward zone and end ceil post zone end + trial_pos = trial_pos.apply( + lambda x: np.floor(x) if x <= ZONE_END else np.ceil(x) + ) + # set to int + trial_pos = trial_pos.astype(int) + + # exclude positions after tunnel reset + trial_pos = trial_pos[trial_pos <= TUNNEL_RESET+1] + + # get firing rates for current trial of all units + trial_fr = fr.xs( + key=trial, + axis=1, + level="trial", + ).dropna(how="all").copy() + + # get all indices before post reset + no_post_reset = trial_fr.index.intersection(trial_pos.index) + # remove post reset rows + trial_fr = trial_fr.loc[no_post_reset] + trial_pos = trial_pos.loc[no_post_reset] + + # put trial positions in trial fr df + trial_fr["position"] = trial_pos.values + # group values by position and get mean + mean_fr = trial_fr.groupby("position")[unit_ids].mean() + # reindex into full tunnel length + pos_fr[trial] = mean_fr.reindex(indices) + # get trial occupancy + pos_count = trial_fr.groupby("position").size() + occupancy[pos_count.index.values, t] = pos_count.values + + # concatenate dfs + pos_fr = pd.concat(pos_fr, axis=1, names=["trial", "unit"]) + pos_fr = pos_fr.reorder_levels(["unit", "trial"], axis=1) + pos_fr = pos_fr.sort_index(level="unit", axis=1) + # convert to df + occupancy = pd.DataFrame( + data=occupancy, + index=indices, + columns=positions.columns, + ) + + return {"pos_fr": pos_fr, "occupancy": occupancy} + + + From 5543979ac5ffa62380a999d229088d118322c163 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Dec 2024 17:38:55 +0000 Subject: [PATCH 121/658] add get_positional_rate at experiment level --- pixels/experiment.py | 49 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/pixels/experiment.py b/pixels/experiment.py index aa7f297..daaab98 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -466,3 +466,52 @@ def get_session_by_name(self, name: str): if session.name == name: return session raise PixelsError + + + def get_positional_rate(self, *args, units=None, **kwargs): + """ + Get positional firing rate for aligned vr trials. + Check behaviours.base.Behaviour.align_trials for usage information. 
+ """ + trials = {} + for i, session in enumerate(self.sessions): + result = None + if units: + if units[i]: + result = session.get_positional_rate( + *args, + units=units[i], + **kwargs, + ) + else: + result = session.get_positional_rate(*args, **kwargs) + if result is not None: + trials[i] = result + + pos_frs = {} + occupancies = {} + for s in trials: + pos_frs[s] = trials[s]["pos_fr"] + occupancies[s] = trials[s]["occupancy"] + + pos_frs_df = pd.concat( + pos_frs.values(), + axis=1, + copy=False, + keys=pos_frs.keys(), + names=["session"] + ) + occu_df = pd.concat( + occupancies.values(), + axis=1, + copy=False, + keys=occupancies.keys(), + names=["session"] + ) + df = { + "pos_fr": pos_frs_df, + "occupancy": occu_df, + } + + return df + From 0dc7978eeff5ef9597f5f366b88af35bae49170f Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Dec 2024 17:40:03 +0000 Subject: [PATCH 122/658] add print to show if cache is loaded --- pixels/behaviours/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1d0a2ba..b5fc332 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -86,10 +86,13 @@ def func(*args, **kwargs): as_list[i] = name as_list.insert(0, method.__name__) + output = self.interim / 'cache' / ('_'.join(str(i) for i in as_list) + '.h5') if output.exists() and self._use_cache != "overwrite": + # load cache try: df = ioutils.read_hdf5(output) + print(f"> Cache loaded from {output}.") except HDF5ExtError: df = None except (KeyError, ValueError): @@ -107,6 +110,7 @@ def func(*args, **kwargs): key_name = key.lstrip("/") # use key name as dict key df[key_name] = data + print(f"> Cache loaded from {output}.") else: df = method(*args, **kwargs) output.parent.mkdir(parents=True, exist_ok=True) From 69da82b578a7cbe91476d76d06111aae10c7c8c7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Dec 2024 17:40:34 +0000 Subject: [PATCH 123/658] change order of kwargs to match with Experiment so that the same cache file can be loaded --- pixels/behaviours/base.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b5fc332..5aafc15 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1839,8 +1839,8 @@ def _get_aligned_spike_times( def _get_aligned_trials( - self, label, event, end_event=None, sigma=None, time_bin=None, - pos_bin=None, units=None, + self, label, event, units=None, sigma=None, end_event=None, + time_bin=None, pos_bin=None, ): """ Returns spike rate for each unit within a trial. @@ -2208,9 +2208,9 @@ def get_lfp_data_raw(self): @_cacheable def align_trials( - self, label, event, data='spike_times', raw=False, duration=1, sigma=None, - units=None, dlc_project=None, video_match=None, end_event=None, - time_bin=None, pos_bin=False, + self, label, event, units=None, data='spike_times', raw=False, + duration=1, sigma=None, dlc_project=None, video_match=None, + end_event=None, time_bin=None, pos_bin=False, ): """ Get trials aligned to an event. 
This finds all instances of label in the action @@ -2291,8 +2291,8 @@ def align_trials( print(f"Aligning {data} to trials.") # we let a dedicated function handle aligning spike times return self._get_aligned_trials( - label, event, end_event=end_event, sigma=sigma, - time_bin=time_bin, pos_bin=pos_bin, units=units, + label, event, units=units, sigma=sigma, end_event=end_event, + time_bin=time_bin, pos_bin=pos_bin, ) if data == "motion_tracking" and not dlc_project: @@ -3003,7 +3003,7 @@ def bin_vr_trial(self, data, positions, time_bin, pos_bin, bin_method="mean"): @_cacheable def get_positional_rate( - self, self, label, event, end_event=None, sigma=None, time_bin=None, + self, label, event, end_event=None, sigma=None, time_bin=None, pos_bin=None, units=None, ): """ @@ -3013,19 +3013,23 @@ def get_positional_rate( # TODO dec 18 2024: # rearrange vr specific funcs to vr module # put pixels specific funcs in pixels module - TUNNEL_RESET = 600 # cm + ZONE_END = 495 # cm + + # NOTE: order of args matters for loading the cache! + # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded # get aligned firing rates and positions trials = self.align_trials( + units=units, label=label, event=event, data="trial_rate", - end_event=end_event, sigma=sigma, + end_event=end_event, time_bin=time_bin, pos_bin=pos_bin, - units=units, ) fr = trials["fr"] positions = trials["positions"] From 7d7bb73f50f6cdcf29279bb87f44249c88b6e83c Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 28 Dec 2024 16:15:36 +0000 Subject: [PATCH 124/658] import constants from vd --- pixels/behaviours/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 5aafc15..2e6c2d3 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3010,11 +3010,12 @@ def get_positional_rate( Get positional firing rate of selected units in vr, and spatial occupancy of each position. """ + # get constants from vd + from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END + # TODO dec 18 2024: # rearrange vr specific funcs to vr module # put pixels specific funcs in pixels module - TUNNEL_RESET = 600 # cm - ZONE_END = 495 # cm # NOTE: order of args matters for loading the cache! 
# always put units first, cuz it is like that in @@ -3022,7 +3023,7 @@ def get_positional_rate( # get aligned firing rates and positions trials = self.align_trials( - units=units, + units=units, # NOTE: ALWAYS the first arg label=label, event=event, data="trial_rate", From aba67f26953561d665c52fc8cf0dbcee1337151f Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 28 Dec 2024 16:16:18 +0000 Subject: [PATCH 125/658] add starting positions as a level --- pixels/behaviours/base.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2e6c2d3..c2ada87 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3087,8 +3087,6 @@ def get_positional_rate( # concatenate dfs pos_fr = pd.concat(pos_fr, axis=1, names=["trial", "unit"]) - pos_fr = pos_fr.reorder_levels(["unit", "trial"], axis=1) - pos_fr = pos_fr.sort_index(level="unit", axis=1) # convert to df occupancy = pd.DataFrame( data=occupancy, @@ -3096,6 +3094,23 @@ def get_positional_rate( columns=positions.columns, ) + # add another level of starting position + # Get the starting index for each trial (column) + starts = occupancy.apply(lambda col: col.first_valid_index()) + # Group trials by their starting index + trial_level = pos_fr.columns.get_level_values("trial") + unit_level = pos_fr.columns.get_level_values("unit") + # map start level + start_level = trial_level.map(starts) + # define new columns + new_cols = pd.MultiIndex.from_arrays( + [start_level, unit_level, trial_level], + names=["start", "unit", "trial"], + ) + pos_fr.columns = new_cols + # sort by unit + pos_fr = pos_fr.sort_index(level="unit", axis=1) + return {"pos_fr": pos_fr, "occupancy": occupancy} From 11d254a708f61e535d800e9051a36ef971c24f9b Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:09:25 +0000 Subject: [PATCH 126/658] use wavpack for compression --- pixels/behaviours/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c2ada87..4d8c995 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -35,6 +35,7 @@ import spikeinterface.postprocessing as spost from scipy import interpolate from tables import HDF5ExtError +from wavpack_numcodecs import WavPack from pixels import ioutils from pixels import signal From f3905bac35bd3320e756f523fab8efe4d15596fa Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:09:47 +0000 Subject: [PATCH 127/658] use 80% cores; create generic wavpack compressor --- pixels/behaviours/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4d8c995..a2cfc7f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -50,13 +50,18 @@ # set si job_kwargs job_kwargs = dict( - n_jobs=-1, + n_jobs=0.8, # 80% core chunk_duration='1s', progress_bar=True, ) si.set_global_job_kwargs(**job_kwargs) +# instantiate WavPack compressor +wv_compressor = WavPack( + level=3, # high compression + bps=None, # lossless +) def _cacheable(method): """ From e90c309b142d35827314438c1b830854fecddde7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:10:15 +0000 Subject: [PATCH 128/658] copy spike meta by default --- pixels/behaviours/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a2cfc7f..84338dd 100644 --- a/pixels/behaviours/base.py +++ 
b/pixels/behaviours/base.py @@ -234,11 +234,8 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self.drop_data() self.spike_meta = [ - ioutils.read_meta(self.find_file(f['spike_meta'], copy=False)) for f in self.files + ioutils.read_meta(self.find_file(f['spike_meta'], copy=True)) for f in self.files ] - #self.lfp_meta = [ - # ioutils.read_meta(self.find_file(f['lfp_meta'], copy=False)) for f in self.files - #] # environmental variable PIXELS_CACHE={0,1} can be {disable,enable} cache self.set_cache(bool(int(os.environ.get("PIXELS_CACHE", 1)))) From 30448e46e86320fb4b6ba2884b04beb6a970ab40 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:10:49 +0000 Subject: [PATCH 129/658] create func to preprocess full-band data and save as zarr; create ap band on the fly --- pixels/behaviours/base.py | 84 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 84338dd..e742876 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -610,20 +610,100 @@ def process_behaviour(self): print("> Done!") + + def preprocess_fullband(self, mc_method="dredge"): + """ + Preprocess full-band pixels data. + + params + === + mc_method: str, motion correction method. + Default: "dredge". + (as of jan 2025, dredge performs better than ks motion correction.) + + return + === + preprocessed: spikeinterface recording. + """ + for rec_num, recording in enumerate(self.files): + print( + f">>>>> Preprocessing data for recording {rec_num+1} " + f"of {len(self.files)}" + ) + + output = self.interim / recording['preprocessed'] + if output.exists(): + continue + + # load recording data + rec = se.read_spikeglx( + folder_path=self.interim, + stream_id="imec0.ap", + stream_name=recording['spike_data'], + all_annotations=True, # include all annotations + ) + + # correct phase shift + print("> do phase shift correction on raw.") + rec_ps = spre.phase_shift(rec) + + print("> do common average referencing.") + cmr = spre.common_reference( + rec_ps, + dtype=np.float32, # change to float for motion correction + ) + + print(f"> correct motion with {mc_method}.") + mcd = spre.correct_motion( + cmr, + preset=mc_method, + #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, + #folder=self.interim, + ) + + mcd.save( + format="zarr", + folder=output, + compressor=wv_compressor, + ) + + def process_spikes(self): """ Process the spike data from the raw neural recording data. 
""" + # preprocess data + self.preprocess_fullband() + for rec_num, recording in enumerate(self.files): print( f">>>>> Processing spike data for recording {rec_num + 1} of {len(self.files)}" ) output = self.processed / recording['spike_processed'] - if output.exists(): continue - data_file = self.find_file(recording['spike_data']) + # load preprocessed + preprocessed = self.find_file(recording['preprocessed']) + rec = si.load_extractor(preprocessed) + + print("> create ap band by high-pass filtering.") + ap_band = spre.bandpass_filter( + rec, + freq_min=300, + freq_max=9000, + ftype="butterworth", + ) + + print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + downsampled = spre.resample(ap_band, self.SAMPLE_RATE) + + downsampled.save( + format="zarr", + folder=output, + compressor=wv_compressor, + ) + """ orig_rate = self.spike_meta[rec_num]['imSampRate'] num_chans = self.spike_meta[rec_num]['nSavedChans'] From b5daffd82bf57046670f474988fc8ba948b19fe2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:11:39 +0000 Subject: [PATCH 130/658] use preprocessed data for lfp process too --- pixels/behaviours/base.py | 42 +++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e742876..da402ac 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -738,13 +738,42 @@ def process_lfp(self): """ Process the LFP data from the raw neural recording data. """ + # preprocess data + self.preprocess_fullband() + for rec_num, recording in enumerate(self.files): print(f">>>>> Processing LFP for recording {rec_num + 1} of {len(self.files)}") output = self.processed / recording['lfp_processed'] if output.exists(): continue + assert 0 + + # load preprocessed + preprocessed = self.processed / recording['preprocessed'] + rec = se.load_extractor(preprocessed) + + # get lfp band + lfp_band = spre.bandpass_filter( + rec, + freq_min=0.5, + freq_max=300, + ftype="butterworth", + ) + + print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + downsampled = spre.resample(lfp_band, self.SAMPLE_RATE) + downsampled.save( + format="zarr", + folder=output, + compressor=wv_compressor, + ) + assert 0 + # get traces + traces = downsampled.get_traces() + + """ data_file = self.find_file(recording['lfp_data']) orig_rate = int(self.lfp_meta[rec_num]['imSampRate']) num_chans = int(self.lfp_meta[rec_num]['nSavedChans']) @@ -770,12 +799,13 @@ def process_lfp(self): #if self._lag[rec_num] is None: # self.sync_data(rec_num, sync_channel=data[:, -1]) #lag_start, lag_end = self._lag[rec_num] + """ sd = self.processed / recording['lfp_sd'] if sd.exists(): continue - SDs = np.std(downsampled, axis=0) + SDs = np.std(traces, axis=0) results = dict( median=np.median(SDs), SDs=SDs.tolist(), @@ -790,11 +820,11 @@ def process_lfp(self): # data = data[- lag_start:] #print(f"> Saving median subtracted & downsampled LFP to {output}") # save in .npy format - np.save( - file=output, - arr=downsampled, - allow_pickle=True, - ) + #np.save( + # file=output, + # arr=downsampled, + # allow_pickle=True, + #) #downsampled = pd.DataFrame(downsampled) #ioutils.write_hdf5(output, downsampled) From 1876eced92f4b338769355e0522c7832c5913888 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:12:44 +0000 Subject: [PATCH 131/658] do not search for lfp data since all data is full-band --- pixels/ioutils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index c44e279..e16ee3f 
100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -51,18 +51,12 @@ def get_data_files(data_dir, session_name): spike_data = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*')) spike_meta = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*')) - lfp_data = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].lf.bin*')) - lfp_meta = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].lf.meta*')) behaviour = sorted(glob.glob(f'{data_dir}/[0-9a-zA-Z_-]*([0-9]).tdms*')) if not spike_data: raise PixelsError(f"{session_name}: could not find raw AP data file.") if not spike_meta: raise PixelsError(f"{session_name}: could not find raw AP metadata file.") - #if not lfp_data: - # raise PixelsError(f"{session_name}: could not find raw LFP data file.") - #if not lfp_meta: - # raise PixelsError(f"{session_name}: could not find raw LFP metadata file.") camera_data = [] camera_meta = [] From df7c050a4e68999e7b231a0d365203ad6aab2079 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:15:26 +0000 Subject: [PATCH 132/658] use spike data name as name stem --- pixels/ioutils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index e16ee3f..57698e7 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -72,8 +72,19 @@ def get_data_files(data_dir, session_name): recording = {} recording['spike_data'] = original_name(spike_recording) recording['spike_meta'] = original_name(spike_meta[num]) - #recording['lfp_data'] = original_name(lfp_data[num]) - #recording['lfp_meta'] = original_name(lfp_meta[num]) + recording['lfp_data'] = recording['spike_data'].with_name( + recording['spike_data'].stem[:-3] + '.lf.zarr' + ) + + recording['spike_processed'] = recording['spike_data'].with_name( + recording['spike_data'].stem[:-3] + '.ap.processed.zarr' + ) + recording['lfp_processed'] = recording['spike_data'].with_name( + recording['spike_data'].stem[:-3] + '.lf.processed.zarr' + ) + recording['lfp_sd'] = recording['spike_data'].with_name( + recording['spike_data'].stem[:-3] + '_lf_sd.json' + ) if behaviour: if len(behaviour) == len(spike_data): From 53b00cc0a9338cb7409af23e56f5993601931f8c Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 6 Jan 2025 19:16:10 +0000 Subject: [PATCH 133/658] use spike data name as name stem --- pixels/ioutils.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 57698e7..2c65565 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -106,19 +106,13 @@ def get_data_files(data_dir, session_name): recording['behaviour_processed'] = None recording['action_labels'] = Path(f'action_labels_{num}.npz') - recording['spike_processed'] = recording['spike_data'].with_name( - recording['spike_data'].stem + '_processed.h5' + recording['preprocessed'] = recording['spike_data'].with_name( + recording['spike_data'].stem[:-3] + '.preprocessed.zarr' ) recording['spike_rate_processed'] = Path(f'spike_rate_{num}.h5') - #recording['lfp_processed'] = recording['lfp_data'].with_name( - # recording['lfp_data'].stem + '_processed.npy' - #) - #recording['lfp_sd'] = recording['lfp_data'].with_name( - # recording['lfp_data'].stem + '_sd.json' - #) - #recording['clustered_channels'] = recording['lfp_data'].with_name( - # f'channel_clustering_results_{num}.h5' - #) + recording['clustered_channels'] = recording['spike_data'].with_name( + f'channel_clustering_results_{num}.h5' + 
) recording['depth_info'] = recording['spike_data'].with_name( f'depth_info_{num}.json' ) From c1a885088357bae8b403474d4c1e89c2a281e4c7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 10 Jan 2025 20:16:55 +0000 Subject: [PATCH 134/658] use stream_id to name stream unique processed files to allow run concatenation --- pixels/ioutils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 2c65565..630a5af 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -76,15 +76,19 @@ def get_data_files(data_dir, session_name): recording['spike_data'].stem[:-3] + '.lf.zarr' ) + stream_id = recording['spike_data'].stem[-8:-3] + recording['preprocessed'] = recording['spike_data'].with_name( + f'{session_name}_{stream_id}.preprocessed.zarr' + ) recording['spike_processed'] = recording['spike_data'].with_name( - recording['spike_data'].stem[:-3] + '.ap.processed.zarr' + f'{session_name}_{stream_id}.ap.processed.zarr' ) recording['lfp_processed'] = recording['spike_data'].with_name( - recording['spike_data'].stem[:-3] + '.lf.processed.zarr' - ) - recording['lfp_sd'] = recording['spike_data'].with_name( - recording['spike_data'].stem[:-3] + '_lf_sd.json' + f'{session_name}_{stream_id}.lf.processed.zarr' ) + #recording['lfp_sd'] = recording['spike_data'].with_name( + # f'{session_name}_{stream_id}_lf_sd.json' + #) if behaviour: if len(behaviour) == len(spike_data): From cbc1b36df9ff32e33f6d3e95d70747cf5d2416b1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 10 Jan 2025 20:18:14 +0000 Subject: [PATCH 135/658] add curated sorting analyser --- pixels/ioutils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 630a5af..213c919 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -125,7 +125,9 @@ def get_data_files(data_dir, session_name): recording['vr'] = recording['spike_data'].with_name( f'{session_name}_vr_synched.h5' ) - recording['sorting_analyser'] = Path(f'ks4/curated_sa.zarr') + recording['sorting_analyser'] = recording['spike_meta'].with_name( + f'curated_sa.zarr' + ) files.append(recording) From 9f3e5721ed39401c04f7c82f78a0d0d962e92e79 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 10 Jan 2025 20:18:28 +0000 Subject: [PATCH 136/658] move preprocessed up --- pixels/ioutils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 213c919..a42eecf 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -110,9 +110,6 @@ def get_data_files(data_dir, session_name): recording['behaviour_processed'] = None recording['action_labels'] = Path(f'action_labels_{num}.npz') - recording['preprocessed'] = recording['spike_data'].with_name( - recording['spike_data'].stem[:-3] + '.preprocessed.zarr' - ) recording['spike_rate_processed'] = Path(f'spike_rate_{num}.h5') recording['clustered_channels'] = recording['spike_data'].with_name( f'channel_clustering_results_{num}.h5' From aba3de85422918c18225eaef8404c696d7428e0e Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 13 Jan 2025 19:11:29 +0000 Subject: [PATCH 137/658] change files into a nested dictionary so that data from the same probe is grouped together --- pixels/ioutils.py | 175 ++++++++++++++++++++++++++-------------------- 1 file changed, 101 insertions(+), 74 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index a42eecf..af9ddf4 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -33,100 +33,127 @@ def get_data_files(data_dir, 
session_name): Returns ------- - A list of dicts, where each dict corresponds to one recording. The dict will contain - these keys to identify data files: - - - spike_data - - spike_meta - - lfp_data - - lfp_meta - - behaviour - - camera_data - - camera_meta - + A nested dicts, where each dict corresponds to one session. Data is + separated to two main categories: pixels and behaviour. + + In `pixels`, data is separated by their stream id, this is to allow: + - easy concatenation of recordings files from the same probe, i.e., + stream id, + - different numbers of pixels recordings and behaviour recordings, + + {session_name:{ + 'pixels':{ + 'imec0':{ + 'ap_raw': [PosixPath('name.bin')], + 'ap_meta': [PosixPath('name.meta')], + 'preprocessed': PosixPath('name.zarr'), + 'ap_downsampled': PosixPath('name.zarr'), + 'lfp_downsampled': PosixPath('name.zarr'), + 'depth_info': PosixPath('name.json'), <== ?? + 'sorting_analyser': PosixPath('name.zarr'), + }, + 'imecN':{ + }, + }, + 'behaviour':{ + 'vr': PosixPath('name.h5'), + 'action_labels': PosixPath('name.npz'), + }, + } """ if session_name != data_dir.stem: data_dir = list(data_dir.glob(f'{session_name}*'))[0] - files = [] - spike_data = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*')) - spike_meta = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*')) - behaviour = sorted(glob.glob(f'{data_dir}/[0-9a-zA-Z_-]*([0-9]).tdms*')) + files = {} + + ap_raw = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*')) + ap_meta = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*')) - if not spike_data: + if not ap_raw: raise PixelsError(f"{session_name}: could not find raw AP data file.") - if not spike_meta: + if not ap_meta: raise PixelsError(f"{session_name}: could not find raw AP metadata file.") - camera_data = [] - camera_meta = [] - for rec in behaviour: - name = Path(rec).stem - rec_vids = sorted(glob.glob(f'{data_dir}/*{name}-*.tdms*')) - vids = [v for v in rec_vids if 'meta' not in v] - camera_data.append(vids) - meta = [v for v in rec_vids if 'meta' in v] - camera_meta.append(meta) - - for num, spike_recording in enumerate(spike_data): - recording = {} - recording['spike_data'] = original_name(spike_recording) - recording['spike_meta'] = original_name(spike_meta[num]) - recording['lfp_data'] = recording['spike_data'].with_name( - recording['spike_data'].stem[:-3] + '.lf.zarr' + pixels = {} + for r, rec in enumerate(ap_raw): + stream_id = rec[-12:-4] + # separate recordings by their stream ids + if stream_id not in pixels: + pixels[stream_id] = { + "ap_raw": [], # there could be mutliple, thus list + "ap_meta": [], + "si_rec": None, # there could be only one, thus None + "CatGT_ap_data": [], + "CatGT_ap_meta": [], + } + + base_name = original_name(rec) + pixels[stream_id]["ap_raw"].append(base_name) + pixels[stream_id]["ap_meta"].append(original_name(ap_meta[r])) + + # spikeinterface cache + pixels[stream_id]["preprocessed"] = base_name.with_name( + f'{session_name}_{stream_id}.preprocessed.zarr' + ) + pixels[stream_id]['sorting_analyser'] = base_name.with_name( + f'curated_sa.zarr' ) - stream_id = recording['spike_data'].stem[-8:-3] - recording['preprocessed'] = recording['spike_data'].with_name( - f'{session_name}_{stream_id}.preprocessed.zarr' + # downsampled ap stream, 300Hz+ + pixels[stream_id]["ap_downsampled"] = base_name.with_name( + f'{session_name}_{stream_id}.downsampled.zarr' ) - recording['spike_processed'] = 
recording['spike_data'].with_name( - f'{session_name}_{stream_id}.ap.processed.zarr' + # downsampled lfp stream, 300Hz- + pixels[stream_id]["lfp_downsampled"] = base_name.with_name( + f'{session_name}_{stream_id[:-3]}.lf.downsampled.zarr' ) - recording['lfp_processed'] = recording['spike_data'].with_name( - f'{session_name}_{stream_id}.lf.processed.zarr' + + # depth info of probe + pixels[stream_id]['depth_info'] = base_name.with_name( + f'depth_info_{stream_id}.json' ) - #recording['lfp_sd'] = recording['spike_data'].with_name( - # f'{session_name}_{stream_id}_lf_sd.json' + pixels[stream_id]['clustered_channels'] = base_name.with_name( + f'channel_clustering_results_{stream_id}.h5' + ) + + # old catgt data + pixels[stream_id]['CatGT_ap_data'].append( + str(base_name).replace("t0", "tcat") + ) + pixels[stream_id]['CatGT_ap_meta'].append( + str(base_name).replace("t0", "tcat") + ) + + #pixels[stream_id]['spike_rate_processed'] = base_name.with_name( + # f'spike_rate_{stream_id}.h5' #) - if behaviour: - if len(behaviour) == len(spike_data): - recording['behaviour'] = original_name(behaviour[num]) - else: - recording['behaviour'] = original_name(behaviour[0]) - recording['behaviour_processed'] = recording['behaviour'].with_name( - recording['behaviour'].stem + '_processed.h5' - ) - - # We only have videos if we also have behavioural TDMS data - if len(camera_data) > num: - recording['camera_data'] = [original_name(d) for d in camera_data[num]] - recording['camera_meta'] = [original_name(d) for d in camera_meta[num]] - recording['motion_index'] = Path(f'motion_index_{num}.npy') - recording['motion_tracking'] = Path(f'motion_tracking_{num}.h5') - else: - recording['behaviour'] = None - recording['behaviour_processed'] = None + pupil_raw = sorted(glob.glob(f'{data_dir}/behaviour/pupil_cam/*.avi*')) - recording['action_labels'] = Path(f'action_labels_{num}.npz') - recording['spike_rate_processed'] = Path(f'spike_rate_{num}.h5') - recording['clustered_channels'] = recording['spike_data'].with_name( - f'channel_clustering_results_{num}.h5' - ) - recording['depth_info'] = recording['spike_data'].with_name( - f'depth_info_{num}.json' + behaviour = { + "pupil_raw": pupil_raw, + } + + behaviour['vr_synched'] = base_name.with_name( + f'{session_name}_vr_synched.h5' + ) + behaviour['action_labels'] = base_name.with_name(f'action_labels.npz') + + if pupil_raw: + behaviour['pupil_processed'] = base_name.with_name( + session_name + '_pupil_processed.h5' ) - recording['CatGT_ap_data'] = str(recording['spike_data']).replace("t0", "tcat") - recording['CatGT_ap_meta'] = str(recording['spike_meta']).replace("t0", "tcat") - recording['vr'] = recording['spike_data'].with_name( - f'{session_name}_vr_synched.h5' + behaviour['motion_index'] = base_name.with_name( + session_name + 'motion_index.npz' ) - recording['sorting_analyser'] = recording['spike_meta'].with_name( - f'curated_sa.zarr' + behaviour['motion_tracking'] = base_name.with_name( + session_name + 'motion_tracking.h5' ) - files.append(recording) + files = { + "pixels": pixels, + "behaviour": behaviour, + } return files From 80923fd6a6ee402cdcf873a05ea575d7d3025cb7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 13 Jan 2025 19:13:10 +0000 Subject: [PATCH 138/658] make sure to get all raw ap meta from each stream --- pixels/behaviours/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index da402ac..b88e807 100644 --- a/pixels/behaviours/base.py +++ 
b/pixels/behaviours/base.py @@ -233,9 +233,12 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self._probe_depths = None self.drop_data() - self.spike_meta = [ - ioutils.read_meta(self.find_file(f['spike_meta'], copy=True)) for f in self.files - ] + self.ap_meta = [] + for stream_id in self.files["pixels"]: + for meta in self.files["pixels"][stream_id]["ap_meta"]: + self.ap_meta.append( + ioutils.read_meta(self.find_file(meta, copy=True)) + ) # environmental variable PIXELS_CACHE={0,1} can be {disable,enable} cache self.set_cache(bool(int(os.environ.get("PIXELS_CACHE", 1)))) From f05994599b1bf04909cb8d9646b86dc2c0a52a03 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 13 Jan 2025 19:16:14 +0000 Subject: [PATCH 139/658] separate implementation of preprocessing and getting files --- pixels/behaviours/base.py | 88 ++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 34 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b88e807..4d55056 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -613,10 +613,32 @@ def process_behaviour(self): print("> Done!") + def _preprocess_raw(self, rec, mc_method): + """ + Implementation of preprocessing on raw pixels data. + """ + # correct phase shift + print("> do phase shift correction on raw.") + rec_ps = spre.phase_shift(rec) + + print("> do common average referencing.") + # NOTE: dtype will be converted to float32 during motion correction + cmr = spre.common_reference( + rec_ps, + ) + + print(f"> correct motion with {mc_method}.") + mcd = spre.correct_motion( + cmr, + preset=mc_method, + #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, + ) + + return mcd - def preprocess_fullband(self, mc_method="dredge"): + def preprocess_raw(self, mc_method="dredge"): """ - Preprocess full-band pixels data. + Preprocess full-band raw pixels data. params === @@ -628,55 +650,53 @@ def preprocess_fullband(self, mc_method="dredge"): === preprocessed: spikeinterface recording. 
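Each of these steps is lazy in spikeinterface: nothing heavy is computed until the result is saved or traces are requested. A minimal standalone version of the same chain, assuming a SpikeGLX folder and stream id (the paths here are placeholders):

import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre

# lazy chain: each call wraps the previous recording without touching the data yet
rec = se.read_spikeglx(folder_path="interim/session_g0", stream_id="imec0.ap")
rec = spre.phase_shift(rec)                       # undo per-channel ADC sampling delays
rec = spre.common_reference(rec)                  # common median reference across channels
rec = spre.correct_motion(rec, preset="dredge")   # non-rigid drift correction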
""" - for rec_num, recording in enumerate(self.files): - print( - f">>>>> Preprocessing data for recording {rec_num+1} " - f"of {len(self.files)}" - ) + # load raw recording as si recording extractor + self.load_raw_ap() - output = self.interim / recording['preprocessed'] + # get pixels streams + streams = self.files["pixels"] + + for stream_id in streams: + # check if exists + output = self.interim / streams[stream_id]['preprocessed'] if output.exists(): continue - # load recording data - rec = se.read_spikeglx( - folder_path=self.interim, - stream_id="imec0.ap", - stream_name=recording['spike_data'], - all_annotations=True, # include all annotations - ) - - # correct phase shift - print("> do phase shift correction on raw.") - rec_ps = spre.phase_shift(rec) + # load si rec + rec = streams[stream_id]["si_rec"] - print("> do common average referencing.") - cmr = spre.common_reference( - rec_ps, - dtype=np.float32, # change to float for motion correction + print( + f">>>>> Preprocessing data for recording from {stream_id} " + f"of {len(streams)}" ) - print(f"> correct motion with {mc_method}.") - mcd = spre.correct_motion( - cmr, - preset=mc_method, - #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, - #folder=self.interim, - ) + shank_groups = rec.get_channel_groups() + if not np.all(shank_groups == shank_groups[0]): + preprocessed = [] + # split by groups + groups = rec.split_by("group") + for group in groups.values(): + preprocessed.append(self._preprocess_raw(group)) + # aggregate groups together + preprocessed = si.aggregate_channels(preprocessed) + else: + preprocessed = self._preprocess_raw(rec, mc_method) - mcd.save( + preprocessed.save( format="zarr", folder=output, compressor=wv_compressor, ) + return None + def process_spikes(self): """ Process the spike data from the raw neural recording data. """ - # preprocess data - self.preprocess_fullband() + # preprocess raw data + self.preprocess_raw() for rec_num, recording in enumerate(self.files): print( @@ -742,7 +762,7 @@ def process_lfp(self): Process the LFP data from the raw neural recording data. 
""" # preprocess data - self.preprocess_fullband() + self.preprocess_raw() for rec_num, recording in enumerate(self.files): print(f">>>>> Processing LFP for recording {rec_num + 1} of {len(self.files)}") From e133702024f77ecf795cbc7c259ea36283dc68aa Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 13 Jan 2025 19:16:42 +0000 Subject: [PATCH 140/658] remove excessive lines --- pixels/behaviours/base.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4d55056..1e46c0f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -793,31 +793,6 @@ def process_lfp(self): compressor=wv_compressor, ) assert 0 - # get traces - traces = downsampled.get_traces() - - """ - data_file = self.find_file(recording['lfp_data']) - orig_rate = int(self.lfp_meta[rec_num]['imSampRate']) - num_chans = int(self.lfp_meta[rec_num]['nSavedChans']) - - print("> Mapping LFP data") - data = se.read_binary(data_file, orig_rate, np.int16, num_chans-1) - #data = ioutils.read_bin(data_file, num_chans) - - print("> Performing median subtraction across channels for each timepoint") - #subtracted = signal.median_subtraction(data, axis=1) - cmred = spre.common_reference(data) - - print(f"> Downsampling to {self.SAMPLE_RATE} Hz") - #downsampled = signal.resample(subtracted, orig_rate, self.SAMPLE_RATE, False) - downsampled = spre.resample(cmred, self.SAMPLE_RATE) - # get traces - downsampled = downsampled.get_traces() - - # TODO jun 10 2024 to get sync channel here and find lag? - #sync_chan = downsampled[:, -1] - #downsampled = downsampled[:, :-1] #if self._lag[rec_num] is None: # self.sync_data(rec_num, sync_channel=data[:, -1]) From a3e8351d10ca1b2a57ef3fab97d82f58faf3143e Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 13 Jan 2025 19:17:04 +0000 Subject: [PATCH 141/658] make sure raw data is loaded correctly with the new self.files structure --- pixels/behaviours/base.py | 90 ++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 53 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1e46c0f..e0caa2d 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -894,82 +894,66 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' )) - - def load_recording(self): + def load_raw_ap(self): """ - Write a function to load recording. + Write a function to load and concatenate raw recording files for each + stream (i.e., probe), so that data from all runs of the same probe can + be preprocessed and sorted together. 
""" - try: - recording = si.load_extractor(self.interim / 'cache/recording.json') - concat_rec = recording - output = os.path.dirname(self.ks_output) - return recording, output - - except: - for _, files in enumerate(self.files): - # TODO jan 4 check if can put line 798-808 here + # if multiple runs for the same probe, concatenate them + streams = self.files["pixels"] + + for stream_id in streams: + stream_files = streams[stream_id] + recs = [] + for r, raw in enumerate(stream_files["ap_raw"]): try: - print("\n> Getting catgt-ed recording...") self.CatGT_dir = Path(self.CatGT_dir[0]) data_dir = self.CatGT_dir - data_file = data_dir / files['CatGT_ap_data'] - metadata = data_dir / files['CatGT_ap_meta'] + data_file = data_dir / stream_files['CatGT_ap_data'][r] + print("\n> Got catgt-ed recording.") except: print(f"\n> Getting the orignial recording...") - data_file = self.find_file(files['spike_data']) - metadata = self.find_file(files['spike_meta']) - - assert 0 - stream_id = data_file.as_posix()[-12:-4] - if stream_id not in streams: - streams[stream_id] = metadata - - for stream_num, stream in enumerate(streams.items()): - stream_id, metadata = stream - # find spike sorting output folder - if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: - output = self.processed / f'sorted_stream_cat_{stream_num}' - else: - output = self.processed / f'sorted_stream_{stream_num}' + data_file = self.find_file(raw) - try: - recording = se.SpikeGLXRecordingExtractor(self.CatGT_dir, stream_id=stream_id) - except ValueError as e: - raise PixelsError( - f"Did the raw data get fully copied to interim? Full error: {e}\n" - ) + # load recording file + rec = se.read_spikeglx( + folder_path=data_file.parent, + stream_id=stream_id, + stream_name=data_file.stem, + all_annotations=True, # include all annotations + ) + recs.append(rec) - # this recording is filtered - recording.annotate(is_filtered=True) + if len(recs) > 1: + # concatenate runs for each probe + concat_recs = si.concatenate_recordings(recs) + else: + concat_recs = recs[0] - # concatenate recording segments - concat_rec = si.concatenate_recordings([recording]) - probe = pi.read_spikeglx(metadata.as_posix()) - concat_rec = concat_rec.set_probe(probe) - # annotate spike data is filtered - concat_rec.annotate(is_filtered=True) + # now the value for streams dict is recording extractor + stream_files["si_rec"] = concat_recs - return concat_rec, output + return None def sort_spikes(self, CatGT_app=None, old=False): """ Run kilosort spike sorting on raw spike data. """ - streams = {} - # set chunks for spikeinterface operations - #job_kwargs = dict( - # n_jobs=-3, # -1: num of job equals num of cores - # chunk_duration="1s", - # progress_bar=True, - #) + # preprocess raw + self.preprocess_raw() + + assert 0 + # TODO jan 13 2025: + # CONTINUE HERE! 
+ # put ks4 here #concat_rec, output = self.load_recording() #assert 0 #TODO: jan 3 see if ks can run normally now using load_recording() self.run_catgt(CatGT_app=CatGT_app) - assert 0 for _, files in enumerate(self.files): if not CatGT_app == None: From a40db969987c6a5df07d44012e789e8fc1aea256 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:33:55 +0000 Subject: [PATCH 142/658] use double quote by default --- pixels/ioutils.py | 146 ++++++++++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 70 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index af9ddf4..332ad50 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -42,32 +42,32 @@ def get_data_files(data_dir, session_name): - different numbers of pixels recordings and behaviour recordings, {session_name:{ - 'pixels':{ - 'imec0':{ - 'ap_raw': [PosixPath('name.bin')], - 'ap_meta': [PosixPath('name.meta')], - 'preprocessed': PosixPath('name.zarr'), - 'ap_downsampled': PosixPath('name.zarr'), - 'lfp_downsampled': PosixPath('name.zarr'), - 'depth_info': PosixPath('name.json'), <== ?? - 'sorting_analyser': PosixPath('name.zarr'), + "pixels":{ + "imec0":{ + "ap_raw": [PosixPath("name.bin")], + "ap_meta": [PosixPath("name.meta")], + "preprocessed": PosixPath("name.zarr"), + "ap_downsampled": PosixPath("name.zarr"), + "lfp_downsampled": PosixPath("name.zarr"), + "depth_info": PosixPath("name.json"), <== ?? + "sorting_analyser": PosixPath("name.zarr"), }, - 'imecN':{ + "imecN":{ }, }, - 'behaviour':{ - 'vr': PosixPath('name.h5'), - 'action_labels': PosixPath('name.npz'), + "behaviour":{ + "vr": PosixPath("name.h5"), + "action_labels": PosixPath("name.npz"), }, } """ if session_name != data_dir.stem: - data_dir = list(data_dir.glob(f'{session_name}*'))[0] + data_dir = list(data_dir.glob(f"{session_name}*"))[0] files = {} - ap_raw = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*')) - ap_meta = sorted(glob.glob(f'{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*')) + ap_raw = sorted(glob.glob(f"{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.bin*")) + ap_meta = sorted(glob.glob(f"{data_dir}/{session_name}_g[0-9]_t0.imec[0-9].ap.meta*")) if not ap_raw: raise PixelsError(f"{session_name}: could not find raw AP data file.") @@ -93,62 +93,68 @@ def get_data_files(data_dir, session_name): # spikeinterface cache pixels[stream_id]["preprocessed"] = base_name.with_name( - f'{session_name}_{stream_id}.preprocessed.zarr' + f"{session_name}_{stream_id}.preprocessed.zarr" ) - pixels[stream_id]['sorting_analyser'] = base_name.with_name( - f'curated_sa.zarr' + pixels[stream_id]["sorting_analyser"] = base_name.with_name( + f"curated_sa.zarr" ) - # downsampled ap stream, 300Hz+ - pixels[stream_id]["ap_downsampled"] = base_name.with_name( - f'{session_name}_{stream_id}.downsampled.zarr' + # extracted ap stream, 300Hz+ + pixels[stream_id]["ap_extracted"] = base_name.with_name( + f"{session_name}_{stream_id}.extracted.zarr" ) - # downsampled lfp stream, 300Hz- - pixels[stream_id]["lfp_downsampled"] = base_name.with_name( - f'{session_name}_{stream_id[:-3]}.lf.downsampled.zarr' + # extracted lfp stream, 300Hz- + pixels[stream_id]["lfp_extracted"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}.lf.extracted.zarr" ) # depth info of probe - pixels[stream_id]['depth_info'] = base_name.with_name( - f'depth_info_{stream_id}.json' + pixels[stream_id]["depth_info"] = base_name.with_name( + f"depth_info_{stream_id}.json" ) - pixels[stream_id]['clustered_channels'] = 
base_name.with_name( - f'channel_clustering_results_{stream_id}.h5' + pixels[stream_id]["clustered_channels"] = base_name.with_name( + f"channel_clustering_results_{stream_id}.h5" ) # old catgt data - pixels[stream_id]['CatGT_ap_data'].append( + pixels[stream_id]["CatGT_ap_data"].append( str(base_name).replace("t0", "tcat") ) - pixels[stream_id]['CatGT_ap_meta'].append( + pixels[stream_id]["CatGT_ap_meta"].append( str(base_name).replace("t0", "tcat") ) - #pixels[stream_id]['spike_rate_processed'] = base_name.with_name( - # f'spike_rate_{stream_id}.h5' + #pixels[stream_id]["spike_rate_processed"] = base_name.with_name( + # f"spike_rate_{stream_id}.h5" #) - pupil_raw = sorted(glob.glob(f'{data_dir}/behaviour/pupil_cam/*.avi*')) + pupil_raw = sorted(glob.glob(f"{data_dir}/behaviour/pupil_cam/*.avi*")) behaviour = { + "vr_synched": [], + "action_labels": [], "pupil_raw": pupil_raw, } - behaviour['vr_synched'] = base_name.with_name( - f'{session_name}_vr_synched.h5' - ) - behaviour['action_labels'] = base_name.with_name(f'action_labels.npz') + behaviour["vr_synched"].append(base_name.with_name( + f"{session_name}_vr_synched.h5" + )) + behaviour["action_labels"].append(base_name.with_name(f"action_labels.npz")) if pupil_raw: - behaviour['pupil_processed'] = base_name.with_name( - session_name + '_pupil_processed.h5' - ) - behaviour['motion_index'] = base_name.with_name( - session_name + 'motion_index.npz' - ) - behaviour['motion_tracking'] = base_name.with_name( - session_name + 'motion_tracking.h5' - ) + behaviour["pupil_processed"] = [] + behaviour["motion_index"] = [] + behaviour["motion_tracking"] = [] + for r, rec in enumerate(pupil_raw): + behaviour["pupil_processed"].append(base_name.with_name( + session_name + "_pupil_processed.h5" + )) + behaviour["motion_index"] = base_name.with_name( + session_name + "motion_index.npz" + ) + behaviour["motion_tracking"] = base_name.with_name( + session_name + "motion_tracking.h5" + ) files = { "pixels": pixels, @@ -163,7 +169,7 @@ def original_name(path): Get the original name of a file, uncompressed, as a pathlib.Path. """ name = os.path.basename(path) - if name.endswith('.tar.gz'): + if name.endswith(".tar.gz"): name = name[:-7] return Path(name) @@ -208,13 +214,13 @@ def read_bin(path, num_chans, channel=None): Returns ------- numpy.memmap array : A 2D memory-mapped array containing containing the binary - file's data. + file"s data. """ if not isinstance(num_chans, int): num_chans = int(num_chans) - mapping = np.memmap(path, np.int16, mode='r').reshape((-1, num_chans)) + mapping = np.memmap(path, np.int16, mode="r").reshape((-1, num_chans)) if channel is not None: mapping = mapping[:, channel] @@ -263,10 +269,10 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): Parameters ---------- video : numpy.ndarray, or generator - Video data to save to file. It's dimensions should be (duration, height, width) + Video data to save to file. It"s dimensions should be (duration, height, width) and data should be of uint8 type. The file extension determines the resultant file type. Alternatively, this can be a generator that yields frames of this - description, in which case 'dims' must also be passed. + description, in which case "dims" must also be passed. path : string / pathlib.Path object File to which the video will be saved. @@ -275,7 +281,7 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): The frame rate of the output video. dims : (int, int) - (height, width) of video. 
This is only needed if 'video' is a generator that + (height, width) of video. This is only needed if "video" is a generator that yields frames, as then the shape cannot be taken from it directly. """ @@ -288,8 +294,8 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): process = ( ffmpeg - .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}', r=frame_rate) - .output(path.as_posix(), pix_fmt='yuv420p', r=frame_rate, crf=0, vcodec='libx264') + .input("pipe:", format="rawvideo", pix_fmt="rgb24", s=f"{width}x{height}", r=frame_rate) + .output(path.as_posix(), pix_fmt="yuv420p", r=frame_rate, crf=0, vcodec="libx264") .overwrite_output() .run_async(pipe_stdin=True) ) @@ -321,12 +327,12 @@ def read_hdf5(path): Returns ------- - pandas.DataFrame : The dataframe stored within the hdf5 file under the name 'df'. + pandas.DataFrame : The dataframe stored within the hdf5 file under the name "df". """ df = pd.read_hdf( path_or_buf=path, - key='df', + key="df", ) return df @@ -363,7 +369,7 @@ def write_hdf5(path, df, key="df", mode="w"): complib="blosc:lz4hc", ) - print('HDF5 saved to ', path) + print("HDF5 saved to ", path) return @@ -371,7 +377,7 @@ def write_hdf5(path, df, key="df", mode="w"): def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): """ Get a list of recording sessions for the specified mice, excluding those whose - metadata contain '"exclude" = True'. + metadata contain "'exclude' = True". Parameters ---------- @@ -398,13 +404,13 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): if not isinstance(mouse_ids, (list, tuple, set)): mouse_ids = [mouse_ids] sessions = {} - raw_dir = data_dir / 'raw' + raw_dir = data_dir / "raw" for mouse in mouse_ids: - mouse_sessions = list(raw_dir.glob(f'*{mouse}')) + mouse_sessions = list(raw_dir.glob(f"*{mouse}")) if not mouse_sessions: - print(f'Found no sessions for: {mouse}') + print(f"Found no sessions for: {mouse}") continue if not meta_dir: @@ -419,7 +425,7 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): )) continue - meta_file = meta_dir / (mouse + '.json') + meta_file = meta_dir / (mouse + ".json") with meta_file.open() as fd: mouse_meta = json.load(fd) # az: change date format into yyyymmdd @@ -433,15 +439,15 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): included_sessions = set() for i, session in enumerate(mouse_meta): try: - meta_date = datetime.datetime.strptime(session['date'], '%Y-%m-%d') + meta_date = datetime.datetime.strptime(session["date"], "%Y-%m-%d") except ValueError: # also allow this format - meta_date = datetime.datetime.strptime(session['date'], '%Y%m%d') + meta_date = datetime.datetime.strptime(session["date"], "%Y%m%d") except TypeError: raise PixelsError(f"{mouse} session #{i}: 'date' not found in JSON.") for index, ses_date in enumerate(session_dates): - if ses_date == meta_date and not session.get('exclude', False): + if ses_date == meta_date and not session.get("exclude", False): name = mouse_sessions[index].stem if name not in sessions: sessions[name] = [] @@ -452,9 +458,9 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): included_sessions.add(name) if included_sessions: - print(f'{mouse} has {len(included_sessions)} sessions:', ", ".join(included_sessions)) + print(f"{mouse} has {len(included_sessions)} sessions:", ", ".join(included_sessions)) else: - print(f'No session dates match between folders and metadata for: {mouse}') + print(f"No session dates match between folders and metadata 
for: {mouse}") return sessions @@ -564,8 +570,8 @@ def tdms_to_video(tdms_path, meta_path, output_path): process = ( ffmpeg - .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}', r=fps) - .output(output_path.as_posix(), pix_fmt='yuv420p', r=fps, crf=20, vcodec='libx264') + .input("pipe:", format="rawvideo", pix_fmt="rgb24", s=f"{width}x{height}", r=fps) + .output(output_path.as_posix(), pix_fmt="yuv420p", r=fps, crf=20, vcodec="libx264") .overwrite_output() .run_async(pipe_stdin=True) ) From 24684405055c1785e5cb6ed94dbf283ddfd55127 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:34:22 +0000 Subject: [PATCH 143/658] change func names in line with base --- pixels/experiment.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index daaab98..7fc29b8 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -109,21 +109,21 @@ def set_cache(self, on): for session in self.sessions: session.set_cache(on) - def process_spikes(self): + def extract_ap(self): """ - Process the spike data from the raw neural recording data for all sessions. + Process the ap band from the raw neural recording data for all sessions. """ for i, session in enumerate(self.sessions): - print(">>>>> Processing spikes for session {} ({} / {})" + print(">>>>> Processing ap band for session {} ({} / {})" .format(session.name, i + 1, len(self.sessions))) - session.process_spikes() + session.extract_ap() - def sort_spikes(self, CatGT_app=None): + def sort_spikes(self, mc_method="dredge"): """ Extract the spikes from raw spike data for all sessions. """ for i, session in enumerate(self.sessions): print(">>>>> Sorting spikes for session {} ({} / {})" .format(session.name, i + 1, len(self.sessions))) - session.sort_spikes(CatGT_app=CatGT_app) + session.sort_spikes(mc_method=mc_method) def assess_noise(self): """ @@ -134,14 +134,15 @@ def assess_noise(self): .format(session.name, i + 1, len(self.sessions))) session.assess_noise() - def process_lfp(self): + def extract_bands(self): """ - Process the LFP data from the raw neural recording data for all sessions. + Extract ap & lfp data from the raw neural recording data for all + sessions. """ for i, session in enumerate(self.sessions): - print(">>>>> Processing LFP data for session {} ({} / {})" + print(">>>>> Extracting ap & lfp data for session {} ({} / {})" .format(session.name, i + 1, len(self.sessions))) - session.process_lfp() + session.extract_bands() def process_behaviour(self): """ From 6106e0680eb20cfeb33f87b31ec949327f79b904 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:35:04 +0000 Subject: [PATCH 144/658] add constants needed for pipeline --- pixels/constants.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 pixels/constants.py diff --git a/pixels/constants.py b/pixels/constants.py new file mode 100644 index 0000000..f5127b6 --- /dev/null +++ b/pixels/constants.py @@ -0,0 +1,7 @@ +""" +This file contains some constants parameters for the pixels pipeline. 
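Each entry in freq_bands below maps a band name to its (low, high) cut-off frequencies in Hz, used as the bandpass limits when that band is extracted from the raw recording.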
+""" +freq_bands = { + "ap":[300, 9000], + "lfp":[0.5, 300], +} From 2dde484a3a0c005d5403adee0e156124025c0864 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:36:31 +0000 Subject: [PATCH 145/658] add imports --- pixels/behaviours/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e0caa2d..43ed297 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -33,6 +33,7 @@ import spikeinterface.exporters as sexp import spikeinterface.preprocessing as spre import spikeinterface.postprocessing as spost +import spikeinterface.qualitymetrics as sqm from scipy import interpolate from tables import HDF5ExtError from wavpack_numcodecs import WavPack @@ -40,6 +41,7 @@ from pixels import ioutils from pixels import signal from pixels.error import PixelsError +from pixels.constants import * if TYPE_CHECKING: from typing import Optional, Literal From 1b1bff090efb78de072a08ee28dec94e20369e1f Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:44:15 +0000 Subject: [PATCH 146/658] make sure sorter output can be found if they exist --- pixels/behaviours/base.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 43ed297..0b2d44e 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -198,21 +198,23 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self.files = ioutils.get_data_files(self.raw, name) - ks_outputs = sorted(glob.glob( + sorted_streams = sorted(glob.glob( str(self.processed) +'/' + f'sorted_stream_*' )) - self.ks_outputs = [None] * len(ks_outputs) - if not len(ks_outputs) == 0: - for stream_num, stream in enumerate(ks_outputs): + self.ks_outputs = [None] * len(sorted_streams) + if not len(sorted_streams) == 0: + for s, stream in enumerate(sorted_streams): path = Path(stream) + # use si sorting analyser + sa = self.files["pixels"]["imec0.ap"]["sorting_analyser"] if stream.split('_')[-2] == 'cat': if not ((path / 'phy_ks3').exists() and len(os.listdir(path / 'phy_ks3'))>17): - self.ks_outputs[stream_num] = path + self.ks_outputs[s] = path else: - self.ks_outputs[stream_num] = path / 'phy_ks3' - else: - self.ks_outputs[stream_num] = path + self.ks_outputs[s] = path / 'phy_ks3' + elif (path / sa).exists(): + self.ks_outputs[s] = path / "sorter_output" else: print(f"\n> {self.name} has not been spike-sorted.") From b6c58596f60eb12534dc6b4b82fe00212cb2456f Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:44:43 +0000 Subject: [PATCH 147/658] use new self.files structure --- pixels/behaviours/base.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 0b2d44e..8c2e00e 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -251,18 +251,19 @@ def drop_data(self): """ Clear attributes that store data to clear some memory. 
""" - # assume each pixels session only has one behaviour session, no matter - # number of probes - #self._action_labels = None - self._action_labels = [None] * len(self.files) - self._behavioural_data = [None] * len(self.files) - self._spike_data = [None] * len(self.files) - self._spike_times_data = [None] * len(self.files) - self._spike_rate_data = [None] * len(self.files) - self._lfp_data = [None] * len(self.files) - self._motion_index = [None] * len(self.files) - self._cluster_info = [None] * len(self.files) - self._probe_depths = [None] * len(self.files) + # NOTE: number of behaviour session is independent of number of probes + self.stream_count = len(self.files["pixels"]) + self.behaviour_count = len(self.files["behaviour"]["action_labels"]) + + self._action_labels = [None] * self.behaviour_count + self._behavioural_data = [None] * self.behaviour_count + self._ap_data = [None] * self.stream_count + self._spike_times_data = [None] * self.stream_count + self._spike_rate_data = [None] * self.stream_count + self._lfp_data = [None] * self.stream_count + self._motion_index = [None] * self.behaviour_count + self._cluster_info = [None] * self.stream_count + self._probe_depths = [None] * self.stream_count self._load_lag() def set_cache(self, on: bool | Literal["overwrite"]) -> None: From 6660d1a341d9eccf15a78b78907b104646f7e863 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:46:44 +0000 Subject: [PATCH 148/658] use double quote --- pixels/behaviours/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 8c2e00e..179fda8 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -278,16 +278,16 @@ def _load_lag(self): Load previously-calculated lag information from a saved file if it exists, otherwise return Nones. """ - lag_file = self.processed / 'lag.json' - self._lag = [None] * len(self.files) + lag_file = self.processed / "lag.json" + self._lag = [None] * self.behaviour_count if lag_file.exists(): with lag_file.open() as fd: lag = json.load(fd) for rec_num, rec_lag in enumerate(lag): - if rec_lag['lag_start'] is None: + if rec_lag["lag_start"] is None: self._lag[rec_num] = None else: - self._lag[rec_num] = (rec_lag['lag_start'], rec_lag['lag_end']) + self._lag[rec_num] = (rec_lag["lag_start"], rec_lag["lag_end"]) def get_probe_depth(self): """ @@ -299,7 +299,7 @@ def get_probe_depth(self): continue if depth is None: try: - depth_file = self.processed / 'depth.txt' + depth_file = self.processed / "depth.txt" with depth_file.open() as fd: self._probe_depths[stream_num] = [float(line) for line in fd.readlines()][0] From acc981ac88b50d7dc5b73c140cfa19ca36b5d985 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:47:08 +0000 Subject: [PATCH 149/658] change spike_data to ap_data to make it more clear --- pixels/behaviours/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 179fda8..0395533 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -227,7 +227,7 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self._action_labels = None self._behavioural_data = None - self._spike_data = None + self._ap_data = None self._spike_times_data = None self._lfp_data = None self._lag = None @@ -1745,7 +1745,7 @@ def get_spike_data(self): """ Returns the processed and downsampled spike data. 
""" - return self._get_processed_data("_spike_data", "spike_processed") + return self._get_processed_data("_ap_data", "spike_processed") def get_lfp_data(self): """ From 64ec639329bc4f8a29b0f23d0fda0e575ef9bb61 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:49:07 +0000 Subject: [PATCH 150/658] add todo and notes --- pixels/behaviours/base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 0395533..381c904 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -390,6 +390,9 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): The sync channel from either the spike or LFP data. """ + # TODO jan 14 2025: + # this func is not used in vr behaviour, since they are synched + # in vd.session print(" Finding lag between sync channels") recording = self.files[rec_num] @@ -455,6 +458,9 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): SYNC_BIN: int, number of rising sync edges for calculating the scaling factor. """ + # TODO jan 14 2025: + # this func does not work with the new self.files structure + edges_list = [] stream_ids = [] #self.CatGT_dir = Path(self.CatGT_dir[0]) @@ -575,6 +581,8 @@ def process_behaviour(self): """ Process behavioural data from raw tdms and align to neuropixels data. """ + # NOTE jan 14 2025: + # this func is not used by vr behaviour for rec_num, recording in enumerate(self.files): print( f">>>>> Processing behaviour for recording {rec_num + 1} of {len(self.files)}" @@ -845,6 +853,9 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: args: str, arguments in catgt. default is None. """ + # TODO jan 14 2025: + # this func is deprecated + assert 0, "deprecated" if CatGT_app == None: CatGT_app = "~/CatGT3.4" # move cwd to catgt From 6205f1e9cfead91b8643edb2550e6dce79890771 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:50:16 +0000 Subject: [PATCH 151/658] allows motion correction method setting; say process shanks separately when it is --- pixels/behaviours/base.py | 48 +++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 381c904..83c1bd4 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -631,22 +631,38 @@ def _preprocess_raw(self, rec, mc_method): Implementation of preprocessing on raw pixels data. 
""" # correct phase shift - print("> do phase shift correction on raw.") + print("> step 1: do phase shift correction.") rec_ps = spre.phase_shift(rec) - print("> do common average referencing.") - # NOTE: dtype will be converted to float32 during motion correction - cmr = spre.common_reference( + # remove bad channels from sorting + print("> step 2: remove bad channels.") + bad_chan_ids, chan_labels = spre.detect_bad_channels( rec_ps, + outside_channels_location="top", ) + labels, counts = np.unique(chan_labels, return_counts=True) + + for label, count in zip(labels, counts): + print(f"\t> Found {count} channels labelled as {label}.") + rec_clean = rec_ps.remove_channels(bad_chan_ids) - print(f"> correct motion with {mc_method}.") - mcd = spre.correct_motion( - cmr, - preset=mc_method, - #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, + print("> step 3: do common median referencing.") + # NOTE: dtype will be converted to float32 during motion correction + cmr = spre.common_reference( + rec_clean, ) + if not mc_method == "ks": + print(f"> step 4: correct motion with {mc_method}.") + mcd = spre.correct_motion( + cmr, + preset=mc_method, + #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, + ) + else: + print(f"> correct motion later with {mc_method}.") + mcd = cmr + return mcd def preprocess_raw(self, mc_method="dredge"): @@ -669,22 +685,26 @@ def preprocess_raw(self, mc_method="dredge"): # get pixels streams streams = self.files["pixels"] - for stream_id in streams: + for stream_id, stream_files in streams.items(): # check if exists - output = self.interim / streams[stream_id]['preprocessed'] + output = self.interim / stream_files["preprocessed"] if output.exists(): + print( + f"> Preprocessed data from {stream_id} loaded." 
+ ) continue - # load si rec - rec = streams[stream_id]["si_rec"] + # load raw si rec + rec = stream_files["si_rec"] print( f">>>>> Preprocessing data for recording from {stream_id} " - f"of {len(streams)}" + f"in total of {self.stream_count} stream(s)" ) shank_groups = rec.get_channel_groups() if not np.all(shank_groups == shank_groups[0]): + print("> Preprocessing shanks separately.") preprocessed = [] # split by groups groups = rec.split_by("group") From 24304c88e109c708c6c4213fead180fc838ca163 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:50:53 +0000 Subject: [PATCH 152/658] add note --- pixels/behaviours/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 83c1bd4..fe0db09 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -715,6 +715,9 @@ def preprocess_raw(self, mc_method="dredge"): else: preprocessed = self._preprocess_raw(rec, mc_method) + # NOTE jan 16 2025: + # BUG: cannot set dtype back to int16, units from ks4 will have + # incorrect amp & loc preprocessed.save( format="zarr", folder=output, From 956e5a96e4ff1f52d05fcf21c06b52ff88a33bf3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:51:45 +0000 Subject: [PATCH 153/658] add func to estimate drift; compile process ap & lfp func into process_band that takes in args to define bandwidth --- pixels/behaviours/base.py | 157 +++++++++++--------------------------- 1 file changed, 45 insertions(+), 112 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index fe0db09..b603b3a 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -726,60 +726,63 @@ def preprocess_raw(self, mc_method="dredge"): return None + def estimate_drift(self): - def process_spikes(self): + from spikeinterface.sortingcomponents.peak_detection\ + import detect_peaks + from spikeinterface.sortingcomponents.peak_localization\ + import localize_peaks + # detect peaks + + return None + + + def extract_bands(self, bands=None): """ - Process the spike data from the raw neural recording data. + extract data of ap and lfp frequency bands from the raw neural recording + data. """ + if bands == None: + bands = freq_bands + # preprocess raw data self.preprocess_raw() - for rec_num, recording in enumerate(self.files): - print( - f">>>>> Processing spike data for recording {rec_num + 1} of {len(self.files)}" - ) - output = self.processed / recording['spike_processed'] - if output.exists(): - continue + streams = self.files["pixels"] + for stream_id, stream_files in streams.items(): + for name, freqs in bands.items(): + output = self.processed / stream_files[f"{name}_extracted"] + if output.exists(): + print(f"> {name} bands from {stream_id} loaded." 
+ ) + continue + + print( + f">>>>> Extracting {name} bands from {stream_id} " + f"in total of {self.stream_count} stream(s)" + ) - # load preprocessed - preprocessed = self.find_file(recording['preprocessed']) - rec = si.load_extractor(preprocessed) + # load preprocessed + preprocessed = self.find_file(stream_files['preprocessed']) + rec = si.load_extractor(preprocessed) - print("> create ap band by high-pass filtering.") - ap_band = spre.bandpass_filter( - rec, - freq_min=300, - freq_max=9000, - ftype="butterworth", - ) + extracted = spre.bandpass_filter( + rec, + freq_min=freqs[0], + freq_max=freqs[1], + ftype="butterworth", + ) - print(f"> Downsampling to {self.SAMPLE_RATE} Hz") - downsampled = spre.resample(ap_band, self.SAMPLE_RATE) + print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + downsampled = spre.resample(extracted, self.SAMPLE_RATE) - downsampled.save( - format="zarr", - folder=output, - compressor=wv_compressor, - ) + downsampled.save( + format="zarr", + folder=output, + compressor=wv_compressor, + ) """ - orig_rate = self.spike_meta[rec_num]['imSampRate'] - num_chans = self.spike_meta[rec_num]['nSavedChans'] - - print("> Mapping spike data") - data = ioutils.read_bin(data_file, num_chans) - - print(f"> Downsampling to {self.SAMPLE_RATE} Hz") - data = signal.resample(data, orig_rate, self.SAMPLE_RATE) - - # Ideally we would median subtract before downsampling, but that takes a - # very long time and is at risk of memory errors, so let's do it after. - print("> Performing median subtraction across rows") - data = signal.median_subtraction(data, axis=0) - print("> Performing median subtraction across columns") - data = signal.median_subtraction(data, axis=1) - if self._lag[rec_num] is None: self.sync_data(rec_num, sync_channel=data[:, -1]) lag_start, lag_end = self._lag[rec_num] @@ -793,76 +796,6 @@ def process_spikes(self): ioutils.write_hdf5(output, data) """ - def process_lfp(self): - """ - Process the LFP data from the raw neural recording data. 
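The estimate_drift stub above only imports the peak-detection building blocks so far; one way it might be fleshed out, assuming `rec` is a preprocessed recording (the method names are spikeinterface sorting components, the parameters are illustrative rather than tuned):

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

peaks = detect_peaks(rec, method="locally_exclusive", detect_threshold=5)
peak_locations = localize_peaks(rec, peaks, method="center_of_mass")
# binning peak depths over time gives a drift map that can guide the choice of
# motion-correction method before sorting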
- """ - # preprocess data - self.preprocess_raw() - - for rec_num, recording in enumerate(self.files): - print(f">>>>> Processing LFP for recording {rec_num + 1} of {len(self.files)}") - - output = self.processed / recording['lfp_processed'] - if output.exists(): - continue - assert 0 - - # load preprocessed - preprocessed = self.processed / recording['preprocessed'] - rec = se.load_extractor(preprocessed) - - # get lfp band - lfp_band = spre.bandpass_filter( - rec, - freq_min=0.5, - freq_max=300, - ftype="butterworth", - ) - - print(f"> Downsampling to {self.SAMPLE_RATE} Hz") - downsampled = spre.resample(lfp_band, self.SAMPLE_RATE) - - downsampled.save( - format="zarr", - folder=output, - compressor=wv_compressor, - ) - assert 0 - - #if self._lag[rec_num] is None: - # self.sync_data(rec_num, sync_channel=data[:, -1]) - #lag_start, lag_end = self._lag[rec_num] - """ - - sd = self.processed / recording['lfp_sd'] - if sd.exists(): - continue - - SDs = np.std(traces, axis=0) - results = dict( - median=np.median(SDs), - SDs=SDs.tolist(), - ) - print(f"> Saving standard deviation (and their median) of each channel") - with open(sd, 'w') as fd: - json.dump(results, fd) - - #if lag_end < 0: - # data = data[:lag_end] - #if lag_start < 0: - # data = data[- lag_start:] - #print(f"> Saving median subtracted & downsampled LFP to {output}") - # save in .npy format - #np.save( - # file=output, - # arr=downsampled, - # allow_pickle=True, - #) - #downsampled = pd.DataFrame(downsampled) - #ioutils.write_hdf5(output, downsampled) - - def run_catgt(self, CatGT_app=None, args=None) -> None: """ This func performs CatGT on copied AP data in the interim. From 0e08706708991b690965fce37314ff7f7026c21a Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:53:08 +0000 Subject: [PATCH 154/658] loop items in dict to make it easier --- pixels/behaviours/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b603b3a..a1b1d85 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -875,8 +875,7 @@ def load_raw_ap(self): # if multiple runs for the same probe, concatenate them streams = self.files["pixels"] - for stream_id in streams: - stream_files = streams[stream_id] + for stream_id, stream_files in streams.items(): recs = [] for r, raw in enumerate(stream_files["ap_raw"]): try: From 13d16f9996f064fc469ea083f93efa9214bf2eaf Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:54:13 +0000 Subject: [PATCH 155/658] use kilosort 4, build sorting analyser, and export sorting report --- pixels/behaviours/base.py | 248 +++++++++++++++++++++++++++++--------- 1 file changed, 192 insertions(+), 56 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a1b1d85..e09f1fe 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -908,79 +908,215 @@ def load_raw_ap(self): return None - def sort_spikes(self, CatGT_app=None, old=False): + def sort_spikes(self, mc_method="dredge"): """ Run kilosort spike sorting on raw spike data. - """ - # preprocess raw - self.preprocess_raw() - assert 0 - # TODO jan 13 2025: - # CONTINUE HERE! 
- # put ks4 here - - #concat_rec, output = self.load_recording() - #assert 0 - #TODO: jan 3 see if ks can run normally now using load_recording() - - self.run_catgt(CatGT_app=CatGT_app) - - for _, files in enumerate(self.files): - if not CatGT_app == None: - print("\n> Sorting catgt-ed spikes.\n") - basename = self.CatGT_dir[0].split('/')[-1] - files['CatGT_ap_data'] = basename + "/" + files['CatGT_ap_data'] - files['CatGT_ap_meta'] = basename + "/" + files['CatGT_ap_meta'] - data_type = 'CatGT_ap_data' - meta_type = 'CatGT_ap_meta' - else: - print(f"\n> using the orignial spike data.\n") - data_type = 'spike_data' - meta_type = 'spike_meta' + params + === + mc_method: str, motion correction method. + Default: "dredge". + (as of jan 2025, dredge performs better than ks motion correction.) - data_file = self.find_file(files[data_type]) - metadata = self.find_file(files[meta_type]) + """ + if not (self.interim.parent/"ks4_with_wavpack.sif").exists(): + raise PixelsError("Have you craeted Singularity image for sorting?") - stream_id = data_file.as_posix()[-12:-4] - if stream_id not in streams: - streams[stream_id] = metadata + # preprocess raw + self.preprocess_raw(mc_method=mc_method) - for stream_num, stream in enumerate(streams.items()): - stream_id, metadata = stream - # find spike sorting output folder - if len(re.findall('_t[0-9]+', data_file.as_posix())) == 0: - output = self.processed / f'sorted_stream_cat_{stream_num}' - else: - output = self.processed / f'sorted_stream_{stream_num}' + if mc_method == "ks": + ks_mc = True + else: + ks_mc = False + # set ks4 parameters + ks4_params = { + "do_correction": ks_mc, + "do_CAR": False, # do not common average reference + "save_preprocessed_copy": True, # save ks4 preprocessed data + } + + streams = self.files["pixels"] + for stream_id, stream_files in streams.items(): + stream_num = stream_id[-4] + assert 0 # check if already sorted and exported - for_phy = output / "phy_ks3" - if not for_phy.exists() or not len(os.listdir(for_phy)) > 1: + sa_dir = output / stream_files["sorting_analyser"] + if not sa_dir.exists() or not len(os.listdir(sa_dir)) > 1: print(f"> {self.name} {stream_id} not sorted/exported.\n") else: print("> Already sorted and exported, next session.\n") continue - try: - recording = se.SpikeGLXRecordingExtractor( - self.CatGT_dir[0], - stream_id=stream_id, + # load preprocessed + preprocessed = self.find_file(stream_files["preprocessed"]) + + # find spike sorting output folder + if not len(re.findall("_t[0-9]+", preprocessed.as_posix())) == 0: + output = self.processed / f"sorted_stream_cat_{stream_num}" + else: + output = self.processed / f"sorted_stream_{stream_num}" + + # load preprocessed rec + rec = si.load_extractor(preprocessed) + + # move current working directory to interim + os.chdir(self.interim) + + # run sorter + sorting = ss.run_sorter( + sorter_name="kilosort4", + recording=rec, + folder=output, + singularity_image=self.interim.parent/"ks4_with_wavpack.sif", + remove_existing_folder=True, + verbose=True, + **ks4_params, + ) + + # load ks preprocessed recording for sorting analyser + ks_preprocessed = se.read_binary( + file_paths=output/"sorter_output/temp_wh.dat", + sampling_frequency=rec.sampling_frequency, + dtype=np.int16, + num_channels=rec.get_num_channels(), + is_filtered=True, + ) + + # attach probe to ks4 preprocessed recording, from the raw + recording = ks_preprocessed.set_probe(rec.get_probe()) + # set properties to make sure sorting & sorting sa have all + # probe properties to form correct 
rec_attributes, esp `group` + recording._properties = rec._properties + + # >>> annotations >>> + annotations = ["probe_0_planar_contour", "probes_info", + "stream_name", "stream_id"] + for ann in annotations: + recording.set_annotation( + annotation_key=ann, + value=rec.get_annotation(ann), + overwrite=True, ) - # this recording is filtered - recording.annotate(is_filtered=True) - except ValueError as e: - raise PixelsError( - f"Did the raw data get fully copied to interim? Full error: {e}\n" + # original file info + recording.annotate(file_origin=str(preprocessed)) + # <<< annotations <<< + + # curate sorter output + # remove duplicate spikes + sorting = sc.remove_duplicated_spikes( + sorting, + censored_period_ms=0.3, + method="keep_first_iterative", + ) + # remove spikes exceeding recording number of samples + sorting = sc.remove_excess_spikes(sorting, recording) + + # remove redundant units created by ks + sorting = sc.remove_redundant_units( + sorting, + duplicate_threshold=0.9, # default is 0.8 + align=False, + remove_strategy="max_spikes", + ) + + # create sorting analyser + sa = si.create_sorting_analyzer( + sorting=sorting, + recording=recording, + ) + + # calculate all extensions BEFORE further steps + # list required extensions for redundant units removal and quality + # metrics + required_extensions = [ + "random_spikes", + "waveforms", + "templates", + "noise_levels", + "unit_locations", + "template_similarity", + "spike_amplitudes", + "correlograms" + ] + sa.compute( + required_extensions, + save=True, + ) + + # make sure to have group id for each unit + if not "group" in sa.sorting.get_property_keys(): + # get shank id, i.e., group + group = sa.recording.get_channel_groups() + # get max peak channel for each unit + max_chan = si.get_template_extremum_channel(sa).values() + # get group id for each unit + unit_group = group[list(max_chan)] + # set unit group as a property for sorting + sa.sorting.set_property( + key="group", + values=unit_group, ) - # concatenate recording segments - concat_rec = si.concatenate_recordings([recording]) - probe = pi.read_spikeglx(metadata.as_posix()) - concat_rec = concat_rec.set_probe(probe) - # annotate spike data is filtered - concat_rec.annotate(is_filtered=True) + # calculate quality metrics + qms = sqm.compute_quality_metrics(sa) + + # export pre curation report + sexp.export_report( + sorting_analyzer=sa, + output_folder=output/"report", + ) + + # >>> get depth of units on each shank >>> + # get probe geometry coordinates + coords = sa.get_probe().contact_positions + # get coordinates of max channel of each unit on probe, column 0 is x-axis, + # column 1 is y-axis/depth, 0 at bottom-left channel. 
+ max_chan_coords = coords[sa.channel_ids_to_indices(max_chan)] + # set coordinates of max channel of each unit as a property of sorting + sa.sorting.set_property( + key="max_chan_coords", + values=max_chan_coords, + ) + # <<< get depth of units on each shank <<< + + # save pre-curated analyser to disk + sa.save_as( + format="zarr", + folder=output/"sa.zarr", + ) + # remove bad units + #rule = "sliding_rp_violation <= 0.1 & amplitude_median <= -50\ + # & amplitude_cutoff < 0.05 & sd_ratio < 1.5 & presence_ratio > 0.9\ + # & snr > 1.1 & rp_contamination < 0.2 & firing_rate > 0.1" + # use the ibl methods, but amplitude_cutoff rather than noise_cutoff + rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -50\ + & presence_ratio > 0.9" + good_qms = qms.query(rule) + # TODO nov 26 2024 + # wait till noise cutoff implemented and include that. + # also see why sliding rp violation gives loads nan. + + # get unit ids + curated_unit_ids = list(good_qms.index) + + # select curated + curated_sa = sa.select_units(curated_unit_ids) + + # save sa to disk + curated_sa.save_as( + format="zarr", + folder=sa_dir, + ) + + # export report + sexp.export_report( + sorting_analyzer=curated_sa, + output_folder=output/"curated_report", + ) + assert 0 if old: print("\n> loading old kilosort 3 results to spikeinterface") sorting = se.read_kilosort(old_ks_output_dir) # avoid re-sort old From d5f4d6225481f15a8e5ba15f8d4f3f794b14a82d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 16 Jan 2025 17:54:49 +0000 Subject: [PATCH 156/658] no need to calculate waveform separately or extract ks labels --- pixels/behaviours/base.py | 76 +-------------------------------------- 1 file changed, 1 insertion(+), 75 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e09f1fe..080a212 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1182,81 +1182,7 @@ def sort_spikes(self, mc_method="dredge"): ) print(f"> {self.name} {stream_id} waveforms extracted, now it is loaded.\n") except: - print(f"> {self.name} {stream_id} waveforms not extracted, extracting now.\n") - #if ks3_output.count_total_num_spikes() - # extract waveforms - waveforms = si.extract_waveforms( - recording=concat_rec, #recording=test, # for testing - sorting=ks3_output, - folder=cache, - load_if_exists=True, # load extracted if available - #load_if_exists=False, # re-calculate everytime - max_spikes_per_unit=500, # None will extract all waveforms - ms_before=2.0, # time before trough - ms_after=3.0, # time after trough - overwrite=False, - #overwrite=True, - **job_kwargs, - ) - - """ - # TODO: remove redundant units by keeping minimum shift, highest_amplitude, or - # max_spikes - ks3_output = sc.remove_redundant_units( - waveforms, # spike trains realigned using the peak shift in template - duplicate_threshold=0.9, # default is 0.8 - remove_strategy='minimum_shift', # keep unit with best peak alignment - ) - """ - # export to phy, with pc feature calculated. - # copy recording.dat to output so that individual waveforms can be - # seen in waveformview. 
- print(f"\n> Exporting {self.name} {stream_id} parameters for phy...\n") - sexp.export_to_phy( - waveform_extractor=waveforms, - output_folder=for_phy, - compute_pc_features=True, # pca - compute_amplitudes=True, - copy_binary=False, - #remove_if_exists=True, # overwrite everytime - remove_if_exists=False, # load if already exists - **job_kwargs, - ) - print(f"> Parameters for manual curation saved to {for_phy}.\n") - - correct_kslabels = for_phy / "cluster_KSLabel.tsv" - if correct_kslabels.exists(): - print(f"\nCorrect KS labels already saved in {correct_kslabels}. Next session.\n") - continue - - print("\n> Getting all KS labels...") - all_ks_labels = pd.read_csv( - output / "sorter_output/cluster_KSLabel.tsv", - sep='\t', - ) - print("\n> Finding cluster ids from spikeinterface output...") - new_clus_ids = pd.read_csv( - for_phy / "cluster_si_unit_ids.tsv", - sep='\t', - ) - units = new_clus_ids.si_unit_id.to_list() - - print("\n> Saving correct ks labels...") - selected_kslabels = all_ks_labels.iloc[units].reset_index(drop=True) - selected_kslabels.loc[:, "cluster_id"] = [i for i in range(new_clus_ids.shape[0])] - selected_kslabels.to_csv( - correct_kslabels, - sep='\t', - index=False, - ) - - # copy params.py from sorter_output to phy_ks3 - print(f"\n> Copying params.py to {for_phy}...") - copyfile(output / "sorter_output/params.py", for_phy / "params.py") - - # TODO jan 8 in sorter_output, only keep params, recording and - # temp_wh, delete the rest - print(f"\n> {self.name} {stream_id} spike-sorted.\n") + assert 0 def extract_videos(self, force=False): From 8771162febbfce9bcd600cea09aab40e2d20106e Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:07:46 +0000 Subject: [PATCH 157/658] add session name and stream id to name --- pixels/ioutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 332ad50..446fa35 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -110,10 +110,10 @@ def get_data_files(data_dir, session_name): # depth info of probe pixels[stream_id]["depth_info"] = base_name.with_name( - f"depth_info_{stream_id}.json" + f"{session_name}_{stream_id}_depth_info.h5" ) pixels[stream_id]["clustered_channels"] = base_name.with_name( - f"channel_clustering_results_{stream_id}.h5" + f"{session_name}_{stream_id}_channel_clustering_results.h5" ) # old catgt data From 21c6a7a66dc4c8c7fd0466bddda74f9d69bb9e49 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:08:03 +0000 Subject: [PATCH 158/658] add of_date arg --- pixels/ioutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 446fa35..2d4ad45 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -374,7 +374,7 @@ def write_hdf5(path, df, key="df", mode="w"): return -def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt): +def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt, of_date=None): """ Get a list of recording sessions for the specified mice, excluding those whose metadata contain "'exclude' = True". 
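
A minimal usage sketch of the new of_date argument, assuming the date filtering introduced in the following patch; the mouse ID, data root and date string below are placeholders, not values from this repository:

    from pathlib import Path
    from pixels import ioutils

    # Select only the session recorded on one day; of_date must match
    # session_date_fmt ("%y%m%d" here, i.e. 20 Jan 2025).
    sessions = ioutils.get_sessions(
        mouse_ids=["C57_01"],            # placeholder mouse ID
        data_dir=Path("/data/pixels"),   # placeholder data root containing raw/
        meta_dir=None,                   # skip metadata collection
        session_date_fmt="%y%m%d",
        of_date="250120",                # keep only this date
    )
    # of_date=None (the default) keeps the old behaviour and returns all sessions.
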
From 8897b1b963dba1e8e8f47cbf3016fca8688b143e Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:08:38 +0000 Subject: [PATCH 159/658] sort session name and dates; allows selecting specific date --- pixels/ioutils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 2d4ad45..b4a8bdb 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -407,12 +407,25 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt, of_date=None): raw_dir = data_dir / "raw" for mouse in mouse_ids: - mouse_sessions = list(raw_dir.glob(f"*{mouse}")) + mouse_sessions = sorted(list(raw_dir.glob(f"*{mouse}"))) if not mouse_sessions: print(f"Found no sessions for: {mouse}") continue + # allows different session date formats + session_dates = sorted([ + datetime.datetime.strptime(s.stem.split("_")[0], session_date_fmt) + for s in mouse_sessions + ]) + + if of_date is not None: + date_struct = datetime.datetime.strptime(of_date, session_date_fmt) + mouse_sessions = [mouse_sessions[session_dates.index(date_struct)]] + print(f"\n> Getting 1 session from {mouse} of " + f"{datetime.datetime.strftime(date_struct, '%Y %B %d')}." + ) + if not meta_dir: # Do not collect metadata for session in mouse_sessions: From 736a6ebaeecb2f847630648c7530ee0e650c8c92 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:09:41 +0000 Subject: [PATCH 160/658] put under else --- pixels/ioutils.py | 65 +++++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 36 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index b4a8bdb..d0f0105 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -437,43 +437,36 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt, of_date=None): data_dir=data_dir, )) continue - - meta_file = meta_dir / (mouse + ".json") - with meta_file.open() as fd: - mouse_meta = json.load(fd) - # az: change date format into yyyymmdd - session_dates = [ - datetime.datetime.strptime(s.stem.split("_")[0], session_date_fmt) for s in mouse_sessions - ] - - if len(session_dates) != len(set(session_dates)): - raise PixelsError(f"{mouse}: Data folder dates must be unique.") - - included_sessions = set() - for i, session in enumerate(mouse_meta): - try: - meta_date = datetime.datetime.strptime(session["date"], "%Y-%m-%d") - except ValueError: - # also allow this format - meta_date = datetime.datetime.strptime(session["date"], "%Y%m%d") - except TypeError: - raise PixelsError(f"{mouse} session #{i}: 'date' not found in JSON.") - - for index, ses_date in enumerate(session_dates): - if ses_date == meta_date and not session.get("exclude", False): - name = mouse_sessions[index].stem - if name not in sessions: - sessions[name] = [] - sessions[name].append(dict( - metadata=session, - data_dir=data_dir, - )) - included_sessions.add(name) - - if included_sessions: - print(f"{mouse} has {len(included_sessions)} sessions:", ", ".join(included_sessions)) else: - print(f"No session dates match between folders and metadata for: {mouse}") + meta_file = meta_dir / (mouse + ".json") + with meta_file.open() as fd: + mouse_meta = json.load(fd) + + if len(session_dates) != len(set(session_dates)): + raise PixelsError(f"{mouse}: Data folder dates must be unique.") + + included_sessions = set() + for i, session in enumerate(mouse_meta): + try: + meta_date = datetime.datetime.strptime(session["date"], session_date_fmt) + except TypeError: + raise PixelsError(f"{mouse} session #{i}: 'date' not 
found in JSON.") + + for index, ses_date in enumerate(session_dates): + if ses_date == meta_date and not session.get("exclude", False): + name = mouse_sessions[index].stem + if name not in sessions: + sessions[name] = [] + sessions[name].append(dict( + metadata=session, + data_dir=data_dir, + )) + included_sessions.add(name) + + if included_sessions: + print(f"{mouse} has {len(included_sessions)} sessions:", ", ".join(included_sessions)) + else: + print(f"No session dates match between folders and metadata for: {mouse}") return sessions From 5a36e45d678647d49bdb94063857dd312526a729 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:10:13 +0000 Subject: [PATCH 161/658] allows specific date --- pixels/experiment.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 7fc29b8..cf9e981 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -52,6 +52,7 @@ def __init__( interim_dir=None, processed_dir=None, session_date_fmt="%y%m%d", + of_date=None, ): if not isinstance(mouse_ids, (list, tuple, set)): mouse_ids = [mouse_ids] @@ -71,10 +72,17 @@ def __init__( self.meta_dir = None self.sessions = [] - sessions = ioutils.get_sessions(mouse_ids, self.data_dir, self.meta_dir, session_date_fmt) + sessions = ioutils.get_sessions( + mouse_ids, + self.data_dir, + self.meta_dir, + session_date_fmt, + of_date, + ) for name, metadata in sessions.items(): - assert len(set(s['data_dir'] for s in metadata)) == 1, "All JSON items with same day must use same data folder." + assert len(set(s['data_dir'] for s in metadata)) == 1,\ + "All JSON items with same day must use same data folder." self.sessions.append( behaviour( name, From b4fd5ef33ac082867cb199d66f6eff248fda1812 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:10:34 +0000 Subject: [PATCH 162/658] reduce motion estimation window size to accommodate short shanks --- pixels/behaviours/base.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 080a212..2b168ce 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -631,11 +631,11 @@ def _preprocess_raw(self, rec, mc_method): Implementation of preprocessing on raw pixels data. 
""" # correct phase shift - print("> step 1: do phase shift correction.") + print("\t> step 1: do phase shift correction.") rec_ps = spre.phase_shift(rec) # remove bad channels from sorting - print("> step 2: remove bad channels.") + print("\t> step 2: remove bad channels.") bad_chan_ids, chan_labels = spre.detect_bad_channels( rec_ps, outside_channels_location="top", @@ -643,20 +643,27 @@ def _preprocess_raw(self, rec, mc_method): labels, counts = np.unique(chan_labels, return_counts=True) for label, count in zip(labels, counts): - print(f"\t> Found {count} channels labelled as {label}.") + print(f"\t\t> Found {count} channels labelled as {label}.") rec_clean = rec_ps.remove_channels(bad_chan_ids) - print("> step 3: do common median referencing.") + print("\t> step 3: do common median referencing.") # NOTE: dtype will be converted to float32 during motion correction cmr = spre.common_reference( rec_clean, ) if not mc_method == "ks": - print(f"> step 4: correct motion with {mc_method}.") + print(f"\t> step 4: correct motion with {mc_method}.") + # reduce spatial window size for four-shank + estimate_motion_kwargs = { + "win_step_um": 100, + "win_margin_um": -150, + } + mcd = spre.correct_motion( cmr, preset=mc_method, + estimate_motion_kwargs=estimate_motion_kwargs, #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, ) else: From 0e2580d8222980d7d9423c1562df44ae7ae38f6c Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:11:27 +0000 Subject: [PATCH 163/658] print shank id --- pixels/behaviours/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2b168ce..a6f0973 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -711,12 +711,12 @@ def preprocess_raw(self, mc_method="dredge"): shank_groups = rec.get_channel_groups() if not np.all(shank_groups == shank_groups[0]): - print("> Preprocessing shanks separately.") preprocessed = [] # split by groups groups = rec.split_by("group") - for group in groups.values(): - preprocessed.append(self._preprocess_raw(group)) + for g, group in enumerate(groups.values()): + print(f"> Preprocessing shank {g}") + preprocessed.append(self._preprocess_raw(group, mc_method)) # aggregate groups together preprocessed = si.aggregate_channels(preprocessed) else: From 0aa3a375d1058c4a6914d4e65b49e6cb916de4de Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:11:43 +0000 Subject: [PATCH 164/658] detect and localise peaks, and plot them --- pixels/behaviours/base.py | 65 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a6f0973..aed95a6 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -733,13 +733,74 @@ def preprocess_raw(self, mc_method="dredge"): return None - def estimate_drift(self): + def estimate_drift(self, rec, loc_method="monopolar_triangulation"): + + """ + Get a sense of possible drifts in the recordings by looking at a + "positional raster plot", i.e. the depth of the spike as function of + time. To do so, we need to detect the peaks, and then to localize them + in space. + + params + === + rec: spikeinterface recording extractor. + + loc_method: str, peak location method. 
+ Default: "monopolar_triangulation" + list of methods: + "center_of_mass", "monopolar_triangulation", "grid_convolution" + to learn more, check: + https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html + """ from spikeinterface.sortingcomponents.peak_detection\ import detect_peaks from spikeinterface.sortingcomponents.peak_localization\ import localize_peaks - # detect peaks + + # step 1: detect peaks + peaks = detect_peaks( + recording=rec, + method="locally_exclusive", + detect_threshold=5, + exclude_sweep_ms=2, + radius_um=50., + ) + # step 2: localize the peaks to get a sense of their putative depths + peak_locations = localize_peaks( + recording=rec, + peaks=peaks, + method=loc_method, + ) + + # step 3: plot + fs = rec.sampling_frequency + fig, ax = plt.subplots( + ncols=2, + squeeze=False, + figsize=(5, 5), + sharey=True, + ) + ax[0, 0].scatter( + peaks["sample_index"] / fs, + peaks["y"], + color="k", + marker=".", + alpha=0.002, + ) + ax[0, 0].set_title(loc_method) + si.plot_probe_map(rec, ax=ax[0, 1]) + ax[0, 1].scatter( + peaks["x"], + peaks["y"], + color="purple", + alpha=0.002, + ) + + assert 0, "test needed" + stream_id = rec.stream_id + fig_name = f"{self.name}_{stream_id}_positional_raster_plot.pdf" + plt.imsave(self.processed/fig_name, fig) return None From 972efc965de8453200604996080f088317321def Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 20 Jan 2025 16:12:16 +0000 Subject: [PATCH 165/658] print session name too --- pixels/behaviours/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index aed95a6..582a82a 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -826,8 +826,8 @@ def extract_bands(self, bands=None): continue print( - f">>>>> Extracting {name} bands from {stream_id} " - f"in total of {self.stream_count} stream(s)" + f">>>>> Extracting {name} bands from {self.name} " + f"{stream_id} in total of {self.stream_count} stream(s)" ) # load preprocessed From 153e7a00ee6ccdafde78912354e8b4a8d6ec8ce3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 12:45:48 +0000 Subject: [PATCH 166/658] no need to get ks output dir at init --- pixels/behaviours/base.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 582a82a..547bd73 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -198,26 +198,6 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self.files = ioutils.get_data_files(self.raw, name) - sorted_streams = sorted(glob.glob( - str(self.processed) +'/' + f'sorted_stream_*' - )) - self.ks_outputs = [None] * len(sorted_streams) - if not len(sorted_streams) == 0: - for s, stream in enumerate(sorted_streams): - path = Path(stream) - # use si sorting analyser - sa = self.files["pixels"]["imec0.ap"]["sorting_analyser"] - if stream.split('_')[-2] == 'cat': - if not ((path / 'phy_ks3').exists() and - len(os.listdir(path / 'phy_ks3'))>17): - self.ks_outputs[s] = path - else: - self.ks_outputs[s] = path / 'phy_ks3' - elif (path / sa).exists(): - self.ks_outputs[s] = path / "sorter_output" - else: - print(f"\n> {self.name} has not been spike-sorted.") - self.CatGT_dir = sorted(glob.glob( str(self.interim) +'/' + f'catgt_{self.name}_g[0-9]' )) From f6c18d9f92f9c15b20cb5ba7f0d4ff8d008ec129 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 12:46:51 +0000 Subject: [PATCH 167/658] add potential way to 
change dtype --- pixels/behaviours/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 547bd73..9182f94 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -705,6 +705,7 @@ def preprocess_raw(self, mc_method="dredge"): # NOTE jan 16 2025: # BUG: cannot set dtype back to int16, units from ks4 will have # incorrect amp & loc + #preprocessed = spre.astype(preprocessed, dtype=np.int16) preprocessed.save( format="zarr", folder=output, From d934c22cf89221f42bfc67ed0ca7603977999b29 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 12:47:30 +0000 Subject: [PATCH 168/658] get stream_num too --- pixels/behaviours/base.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 9182f94..89d0431 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -986,12 +986,9 @@ def sort_spikes(self, mc_method="dredge"): } streams = self.files["pixels"] - for stream_id, stream_files in streams.items(): - stream_num = stream_id[-4] - assert 0 - + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): # check if already sorted and exported - sa_dir = output / stream_files["sorting_analyser"] + sa_dir = self.find_file(stream_files["sorting_analyser"]) if not sa_dir.exists() or not len(os.listdir(sa_dir)) > 1: print(f"> {self.name} {stream_id} not sorted/exported.\n") else: From 4acb78a0625d8f29233235b6787000d240b7aaaf Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 12:47:41 +0000 Subject: [PATCH 169/658] output is the parent dir of sa_dir --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 89d0431..6d9f7bf 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1002,7 +1002,7 @@ def sort_spikes(self, mc_method="dredge"): if not len(re.findall("_t[0-9]+", preprocessed.as_posix())) == 0: output = self.processed / f"sorted_stream_cat_{stream_num}" else: - output = self.processed / f"sorted_stream_{stream_num}" + output = sa_dir.parent # load preprocessed rec rec = si.load_extractor(preprocessed) From 1bc65816a5e7785642825cff411d0a0fe1dc488d Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 12:49:03 +0000 Subject: [PATCH 170/658] loop through streams --- pixels/behaviours/base.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 6d9f7bf..682f2c3 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1832,13 +1832,13 @@ def _get_si_spike_times(self): """ get spike times in second with spikeinterface """ - self.sa_dir = self.find_file(self.files[0]["sorting_analyser"]) - spike_times = self._spike_times_data - for stream_num, stream in enumerate(range(len(spike_times))): + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser - sa = si.load_sorting_analyzer(self.sa_dir) + sa = si.load_sorting_analyzer(sa_dir) times = {} # get spike train @@ -1854,9 +1854,10 @@ def _get_si_spike_times(self): axis=1, names="unit", ) + # get sampling frequency + fs = int(sa.sampling_frequency) # Convert to time into sample rate index - spike_times[stream_num] /= int(self.spike_meta[0]['imSampRate'])\ - / self.SAMPLE_RATE + 
spike_times[stream_num] /= fs / self.SAMPLE_RATE return spike_times[0] # NOTE: only deal with one stream for now From 3a96bc1810da3d04ebd823a18f8a19bbf679fc2c Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 12:49:16 +0000 Subject: [PATCH 171/658] only get units from first stream for now --- pixels/behaviours/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 682f2c3..0550f84 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2269,9 +2269,11 @@ def select_units( """ if use_si: - self.sa_dir = self.find_file(self.files[0]["sorting_analyser"]) + # NOTE: only deal with one stream for now + stream_files = self.files["pixels"]["imec0.ap"] + sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser - sa = si.load_sorting_analyzer(self.sa_dir) + sa = si.load_sorting_analyzer(sa_dir) # get units unit_ids = sa.unit_ids From d613b893cf91e46245b925ce22ced6e1e4d3b689 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 12:49:52 +0000 Subject: [PATCH 172/658] sorting analyser is in `sorted_stream_N` --- pixels/ioutils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index d0f0105..f2f8c3f 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -95,9 +95,8 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["preprocessed"] = base_name.with_name( f"{session_name}_{stream_id}.preprocessed.zarr" ) - pixels[stream_id]["sorting_analyser"] = base_name.with_name( - f"curated_sa.zarr" - ) + pixels[stream_id]["sorting_analyser"] = base_name.parent/\ + f"sorted_stream_{stream_id[-4]}/curated_sa.zarr" # extracted ap stream, 300Hz+ pixels[stream_id]["ap_extracted"] = base_name.with_name( From 6b1bdd4c1e371dd1e3113c2b6a46ca1583b09625 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 18:56:11 +0000 Subject: [PATCH 173/658] use base_name --- pixels/ioutils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index f2f8c3f..d5517a1 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -93,18 +93,18 @@ def get_data_files(data_dir, session_name): # spikeinterface cache pixels[stream_id]["preprocessed"] = base_name.with_name( - f"{session_name}_{stream_id}.preprocessed.zarr" + f"{base_name.stem}.preprocessed.zarr" ) pixels[stream_id]["sorting_analyser"] = base_name.parent/\ f"sorted_stream_{stream_id[-4]}/curated_sa.zarr" # extracted ap stream, 300Hz+ pixels[stream_id]["ap_extracted"] = base_name.with_name( - f"{session_name}_{stream_id}.extracted.zarr" + f"{base_name.stem}.extracted.zarr" ) # extracted lfp stream, 300Hz- pixels[stream_id]["lfp_extracted"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}.lf.extracted.zarr" + f"{base_name.stem[:-3]}.lf.extracted.zarr" ) # depth info of probe From 87155967b6845090e42a8b6cafe38e78a453ef41 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Jan 2025 18:56:42 +0000 Subject: [PATCH 174/658] separate implementation; plot and save --- pixels/behaviours/base.py | 71 ++++++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 0550f84..3664c6b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -714,8 +714,8 @@ def preprocess_raw(self, mc_method="dredge"): return None + def estimate_drift(self, rec, 
loc_method="monopolar_triangulation"): - """ Get a sense of possible drifts in the recordings by looking at a "positional raster plot", i.e. the depth of the spike as function of @@ -733,57 +733,96 @@ def estimate_drift(self, rec, loc_method="monopolar_triangulation"): to learn more, check: https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html """ + shank_groups = rec.get_channel_groups() + if not np.all(shank_groups == shank_groups[0]): + # split by groups + groups = rec.split_by("group") + chan_with_spikes = [] + for g, group in enumerate(groups.values()): + print(f"> Estimate drift of shank {g}") + chan_with_spikes.append( + self._estimate_drift(group, loc_method) + ) + else: + self._estimate_drift(rec, loc_method) + + assert 0 + np.concatenate(chan_with_spikes) + + def _estimate_drift(self, rec, loc_method="monopolar_triangulation"): + """ + implementation of drift estimation. + """ from spikeinterface.sortingcomponents.peak_detection\ import detect_peaks from spikeinterface.sortingcomponents.peak_localization\ import localize_peaks + import spikeinterface.widgets as sw - # step 1: detect peaks + print("> step 1: detect peaks") peaks = detect_peaks( recording=rec, - method="locally_exclusive", + method="by_channel", detect_threshold=5, exclude_sweep_ms=2, - radius_um=50., ) - # step 2: localize the peaks to get a sense of their putative depths + + print("> step 2: localize the peaks to get a sense of their putative " + "depths") peak_locations = localize_peaks( recording=rec, peaks=peaks, method=loc_method, ) + # TODO jan 22 2025 save peaks and plot later? + # save it as df + assert 0 # step 3: plot fs = rec.sampling_frequency fig, ax = plt.subplots( ncols=2, squeeze=False, - figsize=(5, 5), + figsize=(10, 10), sharey=True, ) + # plot peak time vs depth ax[0, 0].scatter( peaks["sample_index"] / fs, - peaks["y"], + peak_locations["y"], color="k", marker=".", alpha=0.002, ) - ax[0, 0].set_title(loc_method) - si.plot_probe_map(rec, ax=ax[0, 1]) + # plot peak locations on probe + sw.plot_probe_map(rec, ax=ax[0, 1]) ax[0, 1].scatter( - peaks["x"], - peaks["y"], + peak_locations["x"], + peak_locations["y"], color="purple", alpha=0.002, ) + y_max = rec.get_channel_locations()[:,1].max() + y_min = int(peak_locations["y"].min()) - 200 + y_min = np.min([y_min, -200]) + + ax[0, 0].set_title(loc_method) + ax[0, 0].set_xlabel("Time (ms)") + ax[0, 0].set_ylabel("Depth (um)") + ax[0, 0].set_ylim([y_min, y_max]) - assert 0, "test needed" - stream_id = rec.stream_id - fig_name = f"{self.name}_{stream_id}_positional_raster_plot.pdf" - plt.imsave(self.processed/fig_name, fig) + chan_idx = rec._parent_channel_indices + idx_with_spikes = np.unique(peaks["channel_index"]) + chan_with_spikes = rec.channel_ids[np.isin(chan_idx, idx_with_spikes)] - return None + stream_id = rec.channel_ids[0][:-4] + group_id = rec.get_channel_groups()[0] + fig_name = (f"{self.name}_{stream_id}_shank{group_id}_" + f"{loc_method}_positional_spike_raster_plot.png") + fig.savefig(self.processed/fig_name, dpi=300) + + return chan_with_spikes def extract_bands(self, bands=None): From d707195af206bfc2a7dc76bcee5d8d8a16154559 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 17:36:35 +0000 Subject: [PATCH 175/658] make downsampling a choice --- pixels/behaviours/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 3664c6b..76558ab 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ 
-861,10 +861,13 @@ def extract_bands(self, bands=None): ftype="butterworth", ) - print(f"> Downsampling to {self.SAMPLE_RATE} Hz") - downsampled = spre.resample(extracted, self.SAMPLE_RATE) + if downsample: + print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + band_data = spre.resample(extracted, self.SAMPLE_RATE) + else: + band_data = extracted - downsampled.save( + band_data.save( format="zarr", folder=output, compressor=wv_compressor, From d9f1e5aff8daf969acde882d76e7504f9fcbab5d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 17:37:07 +0000 Subject: [PATCH 176/658] add ks documentation --- pixels/behaviours/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 76558ab..f5b83ae 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -661,6 +661,7 @@ def preprocess_raw(self, mc_method="dredge"): mc_method: str, motion correction method. Default: "dredge". (as of jan 2025, dredge performs better than ks motion correction.) + "ks": let kilosort do motion correction. return === @@ -1008,7 +1009,7 @@ def sort_spikes(self, mc_method="dredge"): mc_method: str, motion correction method. Default: "dredge". (as of jan 2025, dredge performs better than ks motion correction.) - + "ks": do motion correction with kilosort. """ if not (self.interim.parent/"ks4_with_wavpack.sif").exists(): raise PixelsError("Have you craeted Singularity image for sorting?") From f49d45df70da4a6d65e1d15101dd069920e463aa Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 17:37:32 +0000 Subject: [PATCH 177/658] reduce sweeping time interval size --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f5b83ae..ad8c678 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -766,7 +766,7 @@ def _estimate_drift(self, rec, loc_method="monopolar_triangulation"): recording=rec, method="by_channel", detect_threshold=5, - exclude_sweep_ms=2, + exclude_sweep_ms=0.2, ) print("> step 2: localize the peaks to get a sense of their putative " From 1d08f5d3406afaad8d8639d5ccbd637d817381f5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 17:37:58 +0000 Subject: [PATCH 178/658] do not plot here in the main analysis pipeline --- pixels/behaviours/base.py | 99 +++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 46 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index ad8c678..fd72003 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -776,57 +776,64 @@ def _estimate_drift(self, rec, loc_method="monopolar_triangulation"): peaks=peaks, method=loc_method, ) - # TODO jan 22 2025 save peaks and plot later? 
- # save it as df - assert 0 - # step 3: plot + # get sampling frequency fs = rec.sampling_frequency - fig, ax = plt.subplots( - ncols=2, - squeeze=False, - figsize=(10, 10), - sharey=True, - ) - # plot peak time vs depth - ax[0, 0].scatter( - peaks["sample_index"] / fs, - peak_locations["y"], - color="k", - marker=".", - alpha=0.002, - ) - # plot peak locations on probe - sw.plot_probe_map(rec, ax=ax[0, 1]) - ax[0, 1].scatter( - peak_locations["x"], - peak_locations["y"], - color="purple", - alpha=0.002, - ) - y_max = rec.get_channel_locations()[:,1].max() - y_min = int(peak_locations["y"].min()) - 200 - y_min = np.min([y_min, -200]) - ax[0, 0].set_title(loc_method) - ax[0, 0].set_xlabel("Time (ms)") - ax[0, 0].set_ylabel("Depth (um)") - ax[0, 0].set_ylim([y_min, y_max]) - - chan_idx = rec._parent_channel_indices - idx_with_spikes = np.unique(peaks["channel_index"]) - chan_with_spikes = rec.channel_ids[np.isin(chan_idx, idx_with_spikes)] - - stream_id = rec.channel_ids[0][:-4] - group_id = rec.get_channel_groups()[0] - fig_name = (f"{self.name}_{stream_id}_shank{group_id}_" - f"{loc_method}_positional_spike_raster_plot.png") - fig.savefig(self.processed/fig_name, dpi=300) - - return chan_with_spikes + # save it as df + df_peaks = pd.DataFrame(peaks) + df_peak_locs = pd.DataFrame(peak_locations) + df = pd.concat([df_peaks, df_peak_locs], axis=1) + # add timestamps and channel ids + df["timestamp"] = df.sample_index / fs + df["channel_id"] = rec.get_channel_ids()[df.channel_index.values] + return df - def extract_bands(self, bands=None): + ## step 3: plot + # fig, ax = plt.subplots( + # ncols=2, + # squeeze=False, + # figsize=(10, 10), + # sharey=True, + # ) + # # plot peak time vs depth + # ax[0, 0].scatter( + # peaks["sample_index"] / fs, + # peak_locations["y"], + # color="k", + # marker=".", + # alpha=0.002, + # ) + # # plot peak locations on probe + # sw.plot_probe_map(rec, ax=ax[0, 1]) + # ax[0, 1].scatter( + # peak_locations["x"], + # peak_locations["y"], + # color="purple", + # alpha=0.002, + # ) + # y_max = rec.get_channel_locations()[:,1].max() + # y_min = int(peak_locations["y"].min()) - 200 + # y_min = np.min([y_min, -200]) + + # ax[0, 0].set_title(loc_method) + # ax[0, 0].set_xlabel("Time (ms)") + # ax[0, 0].set_ylabel("Depth (um)") + # ax[0, 0].set_ylim([y_min, y_max]) + + # chan_idx = rec._parent_channel_indices + # idx_with_spikes = np.unique(peaks["channel_index"]) + # chan_with_spikes = rec.channel_ids[np.isin(chan_idx, idx_with_spikes)] + + # stream_id = rec.channel_ids[0][:-4] + # group_id = rec.get_channel_groups()[0] + # fig_name = (f"{self.name}_{stream_id}_shank{group_id}_" + # f"{loc_method}_positional_spike_raster_plot.png") + # fig.savefig(self.processed/fig_name, dpi=300) + + + def extract_bands(self, bands=None, downsample=True): """ extract data of ap and lfp frequency bands from the raw neural recording data. 
From fe5a1fb4b43f7d90782064e361525574aed4fa08 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 17:38:37 +0000 Subject: [PATCH 179/658] add data category when try to find them --- pixels/behaviours/base.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index fd72003..1f3ea3b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -849,8 +849,7 @@ def extract_bands(self, bands=None, downsample=True): for name, freqs in bands.items(): output = self.processed / stream_files[f"{name}_extracted"] if output.exists(): - print(f"> {name} bands from {stream_id} loaded." - ) + print(f"> {name} bands from {stream_id} loaded.") continue print( @@ -1777,7 +1776,7 @@ def _extract_action_labels(self, behavioural_data): """ - def _get_processed_data(self, attr, key): + def _get_processed_data(self, attr, key, category): """ Used by the following get_X methods to load processed data. @@ -1794,18 +1793,21 @@ def _get_processed_data(self, attr, key): saved = getattr(self, attr) if saved[0] is None: - for rec_num, recording in enumerate(self.files): - if key in recording: - file_path = self.processed / recording[key] - if file_path.exists(): - if re.search(r'\.np[yz]$', file_path.suffix): - saved[rec_num] = np.load(file_path) - elif file_path.suffix == '.h5': - saved[rec_num] = ioutils.read_hdf5(file_path) - else: - msg = f"Could not find {attr[1:]} for recording {rec_num}." - msg += f"\nFile should be at: {file_path}" - raise PixelsError(msg) + files = self.files[category] + + if key in files: + dirs = files[key] + for f, file_dir in enumerate(dirs): + file_path = self.processed / file_dir + if file_path.exists(): + if re.search(r'\.np[yz]$', file_path.suffix): + saved[f] = np.load(file_path) + elif file_path.suffix == '.h5': + saved[f] = ioutils.read_hdf5(file_path) + else: + msg = f"Could not find {attr[1:]} for recording {rec_num}." 
+ msg += f"\nFile should be at: {file_path}" + raise PixelsError(msg) return saved def get_action_labels(self): @@ -1815,7 +1817,8 @@ def get_action_labels(self): """ # TODO jul 5 2024: only one action label for a session, make sure it # does not error - return self._get_processed_data("_action_labels", "action_labels") + return self._get_processed_data("_action_labels", "action_labels", + "behaviour") def get_behavioural_data(self): """ From 071e6891878e1bceaccfc920d876fd195fb32df4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 17:39:04 +0000 Subject: [PATCH 180/658] get vr data from behaviour files --- pixels/behaviours/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 1f3ea3b..831c900 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2091,7 +2091,9 @@ def _get_aligned_trials( units = self.select_units() if not pos_bin is None: - vr_dir = self.find_file(self.files[0]['vr']) + behaviour_files = self.files["behaviour"] + # assume only one vr session for now + vr_dir = self.find_file(behaviour_files["vr_synched"][0]) vr_data = ioutils.read_hdf5(vr_dir) # get positions positions = vr_data.position_in_tunnel From f7e6b6dbdc4df4d58cb2eb84e03f220c017fecca Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 17:39:52 +0000 Subject: [PATCH 181/658] change names to make it more clear --- pixels/behaviours/base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 831c900..9dfe3c4 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2141,19 +2141,20 @@ def _get_aligned_trials( bin_frs = {} bin_counts = {} - for rec_num in range(len(self.files)): + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): # TODO jun 12 2024 skip other streams for now - if rec_num > 0: + if stream_num > 0: continue # Account for multiple raw data files - meta = self.spike_meta[rec_num] + meta = self.ap_meta[stream_num] samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 assert samples.is_integer() in_SAMPLE_RATE_scale = (samples * self.SAMPLE_RATE)\ - / int(self.spike_meta[0]['imSampRate']) + / int(self.ap_meta[0]['imSampRate']) cursor_duration = (cursor * self.SAMPLE_RATE)\ - / int(self.spike_meta[0]['imSampRate']) + / int(self.ap_meta[0]['imSampRate']) rec_spikes = spikes[ (cursor_duration <= spikes)\ & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) @@ -2162,8 +2163,8 @@ def _get_aligned_trials( # Account for lag, in case the ephys recording was started before the # behaviour - if not self._lag[rec_num] == None: - lag_start, _ = self._lag[rec_num] + if not self._lag[stream_num] == None: + lag_start, _ = self._lag[stream_num] else: lag_start = timestamps[0] From f6c284e607fca10651d6f39a35081b89e9309ce0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 20:14:10 +0000 Subject: [PATCH 182/658] add underscore --- pixels/ioutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index d5517a1..1cc4a5c 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -152,7 +152,7 @@ def get_data_files(data_dir, session_name): session_name + "motion_index.npz" ) behaviour["motion_tracking"] = base_name.with_name( - session_name + "motion_tracking.h5" + session_name + "_motion_tracking.h5" ) files = { From 
1e893a399cbc500e189f92eba95752683c34730a Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 20:14:20 +0000 Subject: [PATCH 183/658] add detected peaks --- pixels/ioutils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 1cc4a5c..26664e9 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -95,6 +95,9 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["preprocessed"] = base_name.with_name( f"{base_name.stem}.preprocessed.zarr" ) + pixels[stream_id]["detected_peaks"] = base_name.with_name( + f"{base_name.stem}_detected_peaks.h5" + ) pixels[stream_id]["sorting_analyser"] = base_name.parent/\ f"sorted_stream_{stream_id[-4]}/curated_sa.zarr" From 3959d4d09a305d3ab0de77ca09e19593cd158e10 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 23 Jan 2025 20:14:49 +0000 Subject: [PATCH 184/658] do peak detection only on non-downsampled ap band --- pixels/behaviours/base.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 9dfe3c4..4c03e7c 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -716,7 +716,8 @@ def preprocess_raw(self, mc_method="dredge"): return None - def estimate_drift(self, rec, loc_method="monopolar_triangulation"): + def estimate_drift(self, stream_files, + loc_method="monopolar_triangulation"): """ Get a sense of possible drifts in the recordings by looking at a "positional raster plot", i.e. the depth of the spike as function of @@ -734,21 +735,38 @@ def estimate_drift(self, rec, loc_method="monopolar_triangulation"): to learn more, check: https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html """ + output = self.processed / stream_files["detected_peaks"] + if output.exists(): + return ioutils.read_hdf5(output) + + self.preprocess_raw(mc_method="ks") + self.extract_bands(downsample=False) + + # get ap band + ap_file = self.find_file(stream_files["ap_extracted"]) + rec = si.load_extractor(ap_file) + shank_groups = rec.get_channel_groups() if not np.all(shank_groups == shank_groups[0]): # split by groups groups = rec.split_by("group") - chan_with_spikes = [] + dfs = [] for g, group in enumerate(groups.values()): - print(f"> Estimate drift of shank {g}") - chan_with_spikes.append( - self._estimate_drift(group, loc_method) - ) + print(f"\n> Estimate drift of shank {g}") + dfs.append(self._estimate_drift(group, loc_method)) + # concat shanks + df = pd.concat( + dfs, + axis=1, + keys=groups.keys(), + names=["shank", "spike_properties"] + ) else: - self._estimate_drift(rec, loc_method) + df = self._estimate_drift(rec, loc_method) - assert 0 - np.concatenate(chan_with_spikes) + ioutils.write_hdf5(output, df) + + return df def _estimate_drift(self, rec, loc_method="monopolar_triangulation"): From 43a927c86cd539502a2a9299567c8da47e77cbb9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 11:53:41 +0000 Subject: [PATCH 185/658] save motioned corrected to disk, preprocess on the fly --- pixels/ioutils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 26664e9..5574db1 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -83,6 +83,7 @@ def get_data_files(data_dir, session_name): "ap_raw": [], # there could be mutliple, thus list "ap_meta": [], "si_rec": None, # there could be only one, thus None + "preprocessed": None, "CatGT_ap_data": [], "CatGT_ap_meta": [], } @@ 
-92,8 +93,8 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["ap_meta"].append(original_name(ap_meta[r])) # spikeinterface cache - pixels[stream_id]["preprocessed"] = base_name.with_name( - f"{base_name.stem}.preprocessed.zarr" + pixels[stream_id]["motion_corrected"] = base_name.with_name( + f"{base_name.stem}.mcd.zarr" ) pixels[stream_id]["detected_peaks"] = base_name.with_name( f"{base_name.stem}_detected_peaks.h5" From 6afdd338606661f9970d25c4c9aaa900f605d0af Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 11:55:31 +0000 Subject: [PATCH 186/658] put pixels utils in a separate script --- pixels/behaviours/base.py | 280 ++++++++++---------------------------- pixels/pixels_utils.py | 209 ++++++++++++++++++++++++++++ 2 files changed, 279 insertions(+), 210 deletions(-) create mode 100644 pixels/pixels_utils.py diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4c03e7c..589f9a7 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -40,6 +40,7 @@ from pixels import ioutils from pixels import signal +import pixels.pixels_utils as xut from pixels.error import PixelsError from pixels.constants import * @@ -606,53 +607,57 @@ def process_behaviour(self): print("> Done!") - def _preprocess_raw(self, rec, mc_method): + def correct_motion(self, mc_method="dredge"): """ - Implementation of preprocessing on raw pixels data. + Correct motion of recording. + + params + === + mc_method: str, motion correction method. + Default: "dredge". + (as of jan 2025, dredge performs better than ks motion correction.) + "ks": let kilosort do motion correction. + + return + === + None """ - # correct phase shift - print("\t> step 1: do phase shift correction.") - rec_ps = spre.phase_shift(rec) + if mc_method == "ks": + print(f"> Correct motion later with {mc_method}.") + return None - # remove bad channels from sorting - print("\t> step 2: remove bad channels.") - bad_chan_ids, chan_labels = spre.detect_bad_channels( - rec_ps, - outside_channels_location="top", - ) - labels, counts = np.unique(chan_labels, return_counts=True) + # get pixels streams + streams = self.files["pixels"] + + for stream_id, stream_files in streams.items(): + output = self.interim / stream_files["motion_corrected"] + if output.exists(): + print(f"> Motion corrected {stream_id} loaded.") + continue - for label, count in zip(labels, counts): - print(f"\t\t> Found {count} channels labelled as {label}.") - rec_clean = rec_ps.remove_channels(bad_chan_ids) + # preprocess raw recording + self.preprocess_raw() - print("\t> step 3: do common median referencing.") - # NOTE: dtype will be converted to float32 during motion correction - cmr = spre.common_reference( - rec_clean, - ) + # load preprocessed rec + rec = stream_files["preprocessed"] - if not mc_method == "ks": - print(f"\t> step 4: correct motion with {mc_method}.") - # reduce spatial window size for four-shank - estimate_motion_kwargs = { - "win_step_um": 100, - "win_margin_um": -150, - } - - mcd = spre.correct_motion( - cmr, - preset=mc_method, - estimate_motion_kwargs=estimate_motion_kwargs, - #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, + print( + f">>>>> Correcting motion for recording from {stream_id} " + f"in total of {self.stream_count} stream(s) with {mc_method}" ) - else: - print(f"> correct motion later with {mc_method}.") - mcd = cmr - return mcd + mcd = xut.correct_motion(rec) - def preprocess_raw(self, mc_method="dredge"): + mcd.save( + format="zarr", + folder=output, + 
compressor=wv_compressor, + ) + + return None + + + def preprocess_raw(self): """ Preprocess full-band raw pixels data. @@ -674,50 +679,19 @@ def preprocess_raw(self, mc_method="dredge"): streams = self.files["pixels"] for stream_id, stream_files in streams.items(): - # check if exists - output = self.interim / stream_files["preprocessed"] - if output.exists(): - print( - f"> Preprocessed data from {stream_id} loaded." - ) - continue - # load raw si rec rec = stream_files["si_rec"] - print( f">>>>> Preprocessing data for recording from {stream_id} " f"in total of {self.stream_count} stream(s)" ) - shank_groups = rec.get_channel_groups() - if not np.all(shank_groups == shank_groups[0]): - preprocessed = [] - # split by groups - groups = rec.split_by("group") - for g, group in enumerate(groups.values()): - print(f"> Preprocessing shank {g}") - preprocessed.append(self._preprocess_raw(group, mc_method)) - # aggregate groups together - preprocessed = si.aggregate_channels(preprocessed) - else: - preprocessed = self._preprocess_raw(rec, mc_method) - - # NOTE jan 16 2025: - # BUG: cannot set dtype back to int16, units from ks4 will have - # incorrect amp & loc - #preprocessed = spre.astype(preprocessed, dtype=np.int16) - preprocessed.save( - format="zarr", - folder=output, - compressor=wv_compressor, - ) + stream_files["preprocessed"] = xut.preprocess_raw(rec) return None - def estimate_drift(self, stream_files, - loc_method="monopolar_triangulation"): + def detect_n_localise_peaks(self, loc_method="monopolar_triangulation"): """ Get a sense of possible drifts in the recordings by looking at a "positional raster plot", i.e. the depth of the spike as function of @@ -735,132 +709,41 @@ def estimate_drift(self, stream_files, to learn more, check: https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html """ - output = self.processed / stream_files["detected_peaks"] - if output.exists(): - return ioutils.read_hdf5(output) - - self.preprocess_raw(mc_method="ks") - self.extract_bands(downsample=False) - - # get ap band - ap_file = self.find_file(stream_files["ap_extracted"]) - rec = si.load_extractor(ap_file) - - shank_groups = rec.get_channel_groups() - if not np.all(shank_groups == shank_groups[0]): - # split by groups - groups = rec.split_by("group") - dfs = [] - for g, group in enumerate(groups.values()): - print(f"\n> Estimate drift of shank {g}") - dfs.append(self._estimate_drift(group, loc_method)) - # concat shanks - df = pd.concat( - dfs, - axis=1, - keys=groups.keys(), - names=["shank", "spike_properties"] - ) - else: - df = self._estimate_drift(rec, loc_method) - - ioutils.write_hdf5(output, df) - - return df + self.extract_bands("ap") + # get pixels streams + streams = self.files["pixels"] - def _estimate_drift(self, rec, loc_method="monopolar_triangulation"): - """ - implementation of drift estimation. 
- """ - from spikeinterface.sortingcomponents.peak_detection\ - import detect_peaks - from spikeinterface.sortingcomponents.peak_localization\ - import localize_peaks - import spikeinterface.widgets as sw + for stream_id, stream_files in streams.items(): + output = self.processed / stream_files["detected_peaks"] + if output.exists(): + print(f"> Peaks from {stream_id} already detected.") + continue - print("> step 1: detect peaks") - peaks = detect_peaks( - recording=rec, - method="by_channel", - detect_threshold=5, - exclude_sweep_ms=0.2, - ) + # get ap band + ap_file = self.find_file(stream_files["ap_extracted"]) + rec = si.load_extractor(ap_file) - print("> step 2: localize the peaks to get a sense of their putative " - "depths") - peak_locations = localize_peaks( - recording=rec, - peaks=peaks, - method=loc_method, - ) + # detect and localise peaks + df = xut.detect_n_localise_peaks(rec) - # get sampling frequency - fs = rec.sampling_frequency + # write to disk + ioutils.write_hdf5(output, df) - # save it as df - df_peaks = pd.DataFrame(peaks) - df_peak_locs = pd.DataFrame(peak_locations) - df = pd.concat([df_peaks, df_peak_locs], axis=1) - # add timestamps and channel ids - df["timestamp"] = df.sample_index / fs - df["channel_id"] = rec.get_channel_ids()[df.channel_index.values] + return None - return df - ## step 3: plot - # fig, ax = plt.subplots( - # ncols=2, - # squeeze=False, - # figsize=(10, 10), - # sharey=True, - # ) - # # plot peak time vs depth - # ax[0, 0].scatter( - # peaks["sample_index"] / fs, - # peak_locations["y"], - # color="k", - # marker=".", - # alpha=0.002, - # ) - # # plot peak locations on probe - # sw.plot_probe_map(rec, ax=ax[0, 1]) - # ax[0, 1].scatter( - # peak_locations["x"], - # peak_locations["y"], - # color="purple", - # alpha=0.002, - # ) - # y_max = rec.get_channel_locations()[:,1].max() - # y_min = int(peak_locations["y"].min()) - 200 - # y_min = np.min([y_min, -200]) - - # ax[0, 0].set_title(loc_method) - # ax[0, 0].set_xlabel("Time (ms)") - # ax[0, 0].set_ylabel("Depth (um)") - # ax[0, 0].set_ylim([y_min, y_max]) - - # chan_idx = rec._parent_channel_indices - # idx_with_spikes = np.unique(peaks["channel_index"]) - # chan_with_spikes = rec.channel_ids[np.isin(chan_idx, idx_with_spikes)] - - # stream_id = rec.channel_ids[0][:-4] - # group_id = rec.get_channel_groups()[0] - # fig_name = (f"{self.name}_{stream_id}_shank{group_id}_" - # f"{loc_method}_positional_spike_raster_plot.png") - # fig.savefig(self.processed/fig_name, dpi=300) - - - def extract_bands(self, bands=None, downsample=True): + def extract_bands(self, freqs=None): """ extract data of ap and lfp frequency bands from the raw neural recording data. 
""" - if bands == None: + if freqs == None: bands = freq_bands - - # preprocess raw data - self.preprocess_raw() + elif isinstance(freqs, str) and freqs in freq_bands.keys(): + bands = {freqs: freq_bands[freqs]} + elif isinstance(freqs, dict): + bands = freqs streams = self.files["pixels"] for stream_id, stream_files in streams.items(): @@ -992,34 +875,11 @@ def load_raw_ap(self): streams = self.files["pixels"] for stream_id, stream_files in streams.items(): - recs = [] - for r, raw in enumerate(stream_files["ap_raw"]): - try: - self.CatGT_dir = Path(self.CatGT_dir[0]) - data_dir = self.CatGT_dir - data_file = data_dir / stream_files['CatGT_ap_data'][r] - print("\n> Got catgt-ed recording.") - except: - print(f"\n> Getting the orignial recording...") - data_file = self.find_file(raw) - - # load recording file - rec = se.read_spikeglx( - folder_path=data_file.parent, - stream_id=stream_id, - stream_name=data_file.stem, - all_annotations=True, # include all annotations - ) - recs.append(rec) - - if len(recs) > 1: - # concatenate runs for each probe - concat_recs = si.concatenate_recordings(recs) - else: - concat_recs = recs[0] + # get paths of raw + paths = [self.find_file(path) for path in stream_files["ap_raw"]] # now the value for streams dict is recording extractor - stream_files["si_rec"] = concat_recs + stream_files["si_rec"] = xut.load_raw(paths, stream_id) return None diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py new file mode 100644 index 0000000..9061b04 --- /dev/null +++ b/pixels/pixels_utils.py @@ -0,0 +1,209 @@ +import numpy as np + +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.sorters as ss +import spikeinterface.curation as sc +import spikeinterface.exporters as sexp +import spikeinterface.preprocessing as spre +import spikeinterface.postprocessing as spost +import spikeinterface.qualitymetrics as sqm + +# set si job_kwargs +job_kwargs = dict( + n_jobs=0.8, # 80% core + chunk_duration='1s', + progress_bar=True, +) +si.set_global_job_kwargs(**job_kwargs) + +def load_raw(paths, stream_id): + """ + Load raw recording file from spikeglx. + """ + recs = [] + for p, path in enumerate(paths): + # NOTE: if it is catgt data, pass directly `catgt_ap_data` + print(f"\n> Getting the orignial recording...") + # load # recording # file + rec = se.read_spikeglx( + folder_path=path.parent, + stream_id=stream_id, + stream_name=path.stem, + all_annotations=True, # include # all # annotations + ) + recs.append(rec) + if len(recs) > 1: + # concatenate # runs # for # each # probe + concat_recs = si.concatenate_recordings(recs) + else: + concat_recs = recs[0] + + return rec + + +def preprocess_raw(rec): + shank_groups = rec.get_channel_groups() + if not np.all(shank_groups == shank_groups[0]): + preprocessed = [] + # split by groups + groups = rec.split_by("group") + for g, group in enumerate(groups.values()): + print(f"> Preprocessing shank {g}") + cleaned = _preprocess_raw(group) + preprocessed.append(cleaned) + # aggregate groups together + preprocessed = si.aggregate_channels(preprocessed) + else: + preprocessed = _preprocess_raw(rec) + + # NOTE jan 16 2025: + # BUG: cannot set dtype back to int16, units from ks4 will have + # incorrect amp & loc + if not preprocessed.dtype == np.dtype("int16"): + preprocessed = spre.astype(preprocessed, dtype=np.int16) + + return preprocessed + + +def _preprocess_raw(rec): + """ + Implementation of preprocessing on raw pixels data. 
+ """ + # correct phase shift + print("\t> step 1: do phase shift correction.") + rec_ps = spre.phase_shift(rec) + + # remove bad channels from sorting + print("\t> step 2: remove bad channels.") + bad_chan_ids, chan_labels = spre.detect_bad_channels( + rec_ps, + outside_channels_location="top", + ) + labels, counts = np.unique(chan_labels, return_counts=True) + + for label, count in zip(labels, counts): + print(f"\t\t> Found {count} channels labelled as {label}.") + rec_clean = rec_ps.remove_channels(bad_chan_ids) + + print("\t> step 3: do common median referencing.") + # NOTE: dtype will be converted to float32 during motion correction + cmr = spre.common_reference( + rec_clean, + ) + + return cmr + + +def correct_motion(rec, mc_method="dredge"): + """ + Correct motion of recording. + + params + === + mc_method: str, motion correction method. + Default: "dredge". + (as of jan 2025, dredge performs better than ks motion correction.) + "ks": let kilosort do motion correction. + + return + === + None + """ + print(f"\t> correct motion with {mc_method}.") + # reduce spatial window size for four-shank + estimate_motion_kwargs = { + "win_step_um": 100, + "win_margin_um": -150, + } + + mcd = spre.correct_motion( + rec, + preset=mc_method, + estimate_motion_kwargs=estimate_motion_kwargs, + #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, + ) + + # convert to int16 to save space + if not mcd.dtype == np.dtype("int16"): + mcd = spre.astype(mcd, dtype=np.int16) + + return mcd + + +def detect_n_localise_peaks(rec): + """ + Get a sense of possible drifts in the recordings by looking at a + "positional raster plot", i.e. the depth of the spike as function of + time. To do so, we need to detect the peaks, and then to localize them + in space. + + params + === + rec: spikeinterface recording extractor. + + loc_method: str, peak location method. + Default: "monopolar_triangulation" + list of methods: + "center_of_mass", "monopolar_triangulation", "grid_convolution" + to learn more, check: + https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html + """ + shank_groups = rec.get_channel_groups() + if not np.all(shank_groups == shank_groups[0]): + # split by groups + groups = rec.split_by("group") + dfs = [] + for g, group in enumerate(groups.values()): + print(f"\n> Estimate drift of shank {g}") + dfs.append(_estimate_drift(group, loc_method)) + # concat shanks + df = pd.concat( + dfs, + axis=1, + keys=groups.keys(), + names=["shank", "spike_properties"] + ) + else: + df = self._estimate_drift(rec, loc_method) + + return df + + +def _detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): + """ + implementation of drift estimation. 
+ """ + from spikeinterface.sortingcomponents.peak_detection\ + import detect_peaks + from spikeinterface.sortingcomponents.peak_localization\ + import localize_peaks + + print("> step 1: detect peaks") + peaks = detect_peaks( + recording=rec, + method="by_channel", + detect_threshold=5, + exclude_sweep_ms=0.2, + ) + + print("> step 2: localize the peaks to get a sense of their putative " + "depths") + peak_locations = localize_peaks( + recording=rec, + peaks=peaks, + method=loc_method, + ) + + # get sampling frequency + fs = rec.sampling_frequency + + # save it as df + df_peaks = pd.DataFrame(peaks) + df_peak_locs = pd.DataFrame(peak_locations) + df = pd.concat([df_peaks, df_peak_locs], axis=1) + # add timestamps and channel ids + df["timestamp"] = df.sample_index / fs + df["channel_id"] = rec.get_channel_ids()[df.channel_index.values] + + return df From ddc166987123f1328434c6cbfee4b8d063fcd901 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 11:56:33 +0000 Subject: [PATCH 187/658] remove unused import --- pixels/behaviours/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 589f9a7..06a06a8 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -24,7 +24,6 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -import probeinterface as pi import scipy.stats import spikeinterface as si import spikeinterface.extractors as se From a20766a1263e1134feaf28149f82b9bb9800ddf1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 12:32:27 +0000 Subject: [PATCH 188/658] put band extractor and spike sorting func in pixels utils --- pixels/pixels_utils.py | 201 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9061b04..ccf0c30 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -207,3 +207,204 @@ def _detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): df["channel_id"] = rec.get_channel_ids()[df.channel_index.values] return df + + +def extract_band(rec, freq_min, freq_max): + band = spre.bandpass_filter( + rec, + freq_min=freq_min, + freq_max=freq_min, + ftype="butterworth", + ) + + return band + + +def sort_spikes(rec, output, curated_sa_dir, ks_image_path, ks4_params): + sorting, recording = _sort_spikes( + rec, + output, + ks_image_path, + ks4_params, + ) + + sa, curated_sa = _curate_sorting( + sorting, + recording, + output, + ) + + _export_sorting_analyser( + sa, + curated_sa, + output, + curated_sa_dir, + ) + + return None + + +def _sort_spikes(rec, output, ks_image_path, ks4_params): + assert 0 + # run sorter + sorting = ss.run_sorter( + sorter_name="kilosort4", + recording=rec, + folder=output, + singularity_image=ks_image_path/"ks4_with_wavpack.sif", + remove_existing_folder=True, + verbose=True, + **ks4_params, + ) + + # load ks preprocessed recording for # sorting analyser + ks_preprocessed = se.read_binary( + file_paths=output/"sorter_output/temp_wh.dat", + sampling_frequency=rec.sampling_frequency, + dtype=np.int16, + num_channels=rec.get_num_channels(), + is_filtered=True, + ) + # attach probe # to ks4 preprocessed recording, from the raw + recording = ks_preprocessed.set_probe(rec.get_probe()) + # set properties to make sure sorting & sorting sa have all + # probe # properties to form correct rec_attributes, esp + recording._properties = rec._properties + + # >>> annotations >>> + annotations = rec.get_annotation_keys() + 
annotations.remove("is_filtered") + for ann in annotations: + recording.set_annotation( + annotation_key=ann, + value=rec.get_annotation(ann), + overwrite=True, + ) + # <<< annotations <<< + + return sorting, recording + + +def _curate_n_export(sorting, recording, output): + # curate sorter output + # remove spikes exceeding recording number of samples + sorting = sc.remove_excess_spikes(sorting, recording) + # remove duplicate spikes + sorting = sc.remove_duplicated_spikes( + sorting, + censored_period_ms=0.3, + method="keep_first_iterative", + ) + # remove redundant units created by ks + sorting = sc.remove_redundant_units( + sorting, + duplicate_threshold=0.9, # default is 0.8 + align=False, + remove_strategy="max_spikes", + ) + + # create sorting analyser + sa = si.create_sorting_analyzer( + sorting=sorting, + recording=recording, + ) + + # calculate all extensions BEFORE further steps + # list required extensions for redundant units removal and quality + # metrics + required_extensions = [ + "random_spikes", + "waveforms", + "templates", + "noise_levels", + "unit_locations", + "template_similarity", + "spike_amplitudes", + "correlograms", + "principal_components", # for # phy + ] + sa.compute(required_extensions, save=True) + + # make sure to have group id for each unit + if not "group" in sa.sorting.get_property_keys(): + # get shank id, i.e., group + group = sa.recording.get_channel_groups() + # get max peak channel for each unit + max_chan = si.get_template_extremum_channel(sa).values() + # get group id for each unit + unit_group = group[list(max_chan)] + # set unit group as a property for sorting + sa.sorting.set_property( + key="group", + values=unit_group, + ) + + # calculate quality metrics + qms = sqm.compute_quality_metrics(sa) + + # >>> get depth of units on each shank >>> + # get probe geometry coordinates + coords = sa.get_channel_locations() + # get coordinates of max channel of each unit on probe, column 0 is + # x-axis, column 1 is y-axis/depth, 0 at bottom-left channel. + max_chan_coords = coords[sa.channel_ids_to_indices(max_chan)] + # set coordinates of max channel of each unit as a property of sorting + sa.sorting.set_property( + key="max_chan_coords", + values=max_chan_coords, + ) + # <<< get depth of units on each shank <<< + + # remove bad units + #rule = "sliding_rp_violation <= 0.1 & amplitude_median <= -50\ + # & amplitude_cutoff < 0.05 & sd_ratio < 1.5 & presence_ratio > 0.9\ + # & snr > 1.1 & rp_contamination < 0.2 & firing_rate > 0.1" + # use the ibl methods, but amplitude_cutoff rather than noise_cutoff + rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -50\ + & presence_ratio > 0.9" + good_qms = qms.query(rule) + # TODO nov 26 2024 + # wait till noise cutoff implemented and include that. + # also see why sliding rp violation gives loads nan. 
+ # get unit ids + curated_unit_ids = list(good_qms.index) + # select curated + curated_sa = sa.select_units(curated_unit_ids) + + return sa, curated_sa + + +def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): + + # export pre curation report + sexp.export_report( + sorting_analyzer=sa, + output_folder=output/"report", + ) + + # save pre-curated analyser to disk + sa.save_as( + format="zarr", + folder=output/"sa.zarr", + ) + + # export curated report + sexp.export_report( + sorting_analyzer=curated_sa, + output_folder=output/"curated_report", + ) + + # export to phy for additional manual curation if needed + sexp.export_to_phy( + sorting_analyzer=curated_sa, + output_folder=output/"phy", + copy_binary=False, + ) + + # save sa to disk + curated_sa.save_as( + format="zarr", + folder=curated_sa_dir, + ) + + return None From 50e637de3442b2ec510c933ecf8f13fcd1879e2c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 12:35:32 +0000 Subject: [PATCH 189/658] start print on a new line to make it more readable --- pixels/behaviours/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 06a06a8..fcfaec6 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -641,7 +641,7 @@ def correct_motion(self, mc_method="dredge"): rec = stream_files["preprocessed"] print( - f">>>>> Correcting motion for recording from {stream_id} " + f"\n>>>>> Correcting motion for recording from {stream_id} " f"in total of {self.stream_count} stream(s) with {mc_method}" ) @@ -681,7 +681,7 @@ def preprocess_raw(self): # load raw si rec rec = stream_files["si_rec"] print( - f">>>>> Preprocessing data for recording from {stream_id} " + f"\n>>>>> Preprocessing data for recording from {stream_id} " f"in total of {self.stream_count} stream(s)" ) From b60c1ddb56da91c18ec7b7208eb581bbd172dbf2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 12:36:38 +0000 Subject: [PATCH 190/658] remove unused import --- pixels/behaviours/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index fcfaec6..552ef37 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -30,8 +30,6 @@ import spikeinterface.sorters as ss import spikeinterface.curation as sc import spikeinterface.exporters as sexp -import spikeinterface.preprocessing as spre -import spikeinterface.postprocessing as spost import spikeinterface.qualitymetrics as sqm from scipy import interpolate from tables import HDF5ExtError From cb1e70b41d4f7cca780b58f8429df972ae489866 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 12:37:17 +0000 Subject: [PATCH 191/658] use extract_bands func in pixels utils --- pixels/behaviours/base.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 552ef37..12338d6 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -750,29 +750,26 @@ def extract_bands(self, freqs=None): print(f"> {name} bands from {stream_id} loaded.") continue + # preprocess raw data + self.preprocess_raw() + print( f">>>>> Extracting {name} bands from {self.name} " f"{stream_id} in total of {self.stream_count} stream(s)" ) # load preprocessed - preprocessed = self.find_file(stream_files['preprocessed']) - rec = si.load_extractor(preprocessed) + rec = stream_files["preprocessed"] - extracted = spre.bandpass_filter( + # do bandpass 
filtering + extracted = xut.extract_band( rec, freq_min=freqs[0], freq_max=freqs[1], - ftype="butterworth", ) - if downsample: - print(f"> Downsampling to {self.SAMPLE_RATE} Hz") - band_data = spre.resample(extracted, self.SAMPLE_RATE) - else: - band_data = extracted - - band_data.save( + # write to disk + extracted.save( format="zarr", folder=output, compressor=wv_compressor, From cb375968bbe5828cc753647d21ee087cc770c8f5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 12:37:39 +0000 Subject: [PATCH 192/658] correct motion before spike sorting --- pixels/behaviours/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 12338d6..e5d6590 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -892,13 +892,14 @@ def sort_spikes(self, mc_method="dredge"): if not (self.interim.parent/"ks4_with_wavpack.sif").exists(): raise PixelsError("Have you craeted Singularity image for sorting?") - # preprocess raw - self.preprocess_raw(mc_method=mc_method) + # preprocess and motion correct raw + self.correct_motion(mc_method) if mc_method == "ks": ks_mc = True else: ks_mc = False + # set ks4 parameters ks4_params = { "do_correction": ks_mc, From ab9f8d9dedb3ab93c7029061a36c00b3fa0de3a7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 13:41:01 +0000 Subject: [PATCH 193/658] add documentation --- pixels/pixels_utils.py | 78 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index ccf0c30..71e7066 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -221,6 +221,29 @@ def extract_band(rec, freq_min, freq_max): def sort_spikes(rec, output, curated_sa_dir, ks_image_path, ks4_params): + """ + Sort spikes with kilosort 4, curate sorting, save sorting analyser to disk, + and export results to disk. + + params + === + rec: spikeinterface recording object. + + output: path object, directory of output. + + curated_sa_dir: path object, directory to save curated sorting analyser. + + ks_image_path: path object, directory of local kilosort 4 singularity image. + + ks4_params: dict, parameters for kilosort 4. + + return + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. + """ + # sort spikes sorting, recording = _sort_spikes( rec, output, @@ -228,12 +251,14 @@ def sort_spikes(rec, output, curated_sa_dir, ks_image_path, ks4_params): ks4_params, ) + # curate sorting sa, curated_sa = _curate_sorting( sorting, recording, output, ) + # export sorting analyser _export_sorting_analyser( sa, curated_sa, @@ -245,7 +270,25 @@ def sort_spikes(rec, output, curated_sa_dir, ks_image_path, ks4_params): def _sort_spikes(rec, output, ks_image_path, ks4_params): - assert 0 + """ + Sort spikes with kilosort 4. + + params + === + rec: spikeinterface recording object. + + output: path object, directory of output. + + ks_image_path: path object, directory of local kilosort 4 singularity image. + + ks4_params: dict, parameters for kilosort 4. + + return + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. 
+ """ # run sorter sorting = ss.run_sorter( sorter_name="kilosort4", @@ -285,7 +328,24 @@ def _sort_spikes(rec, output, ks_image_path, ks4_params): return sorting, recording -def _curate_n_export(sorting, recording, output): +def _curate_sorting(sorting, recording, output): + """ + Curate spike sorting results, and export to disk. + + params + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. + + output: path object, directory of output. + + return + === + sa: spikeinterface sorting analyser. + + curated_sa: curated spikeinterface sorting analyser. + """ # curate sorter output # remove spikes exceeding recording number of samples sorting = sc.remove_excess_spikes(sorting, recording) @@ -375,7 +435,21 @@ def _curate_n_export(sorting, recording, output): def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): + """ + Export sorting analyser to disk. + + params + === + sa: spikeinterface sorting analyser. + + curated_sa_dir: path object, directory to save curated sorting analyser. + output: path object, directory of output. + + return + === + None + """ # export pre curation report sexp.export_report( sorting_analyzer=sa, From c124d5829c5cb69186682b0cc0e28073e319cb9d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 13:45:28 +0000 Subject: [PATCH 194/658] point directly the kilosort image --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 71e7066..fc71611 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -294,7 +294,7 @@ def _sort_spikes(rec, output, ks_image_path, ks4_params): sorter_name="kilosort4", recording=rec, folder=output, - singularity_image=ks_image_path/"ks4_with_wavpack.sif", + singularity_image=ks_image_path, remove_existing_folder=True, verbose=True, **ks4_params, From ab686f41732cbde2bee61d3ed21c208e41cba4c4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 13:46:09 +0000 Subject: [PATCH 195/658] save ks4 image directory as variable --- pixels/behaviours/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e5d6590..c7479fc 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -889,7 +889,9 @@ def sort_spikes(self, mc_method="dredge"): (as of jan 2025, dredge performs better than ks motion correction.) "ks": do motion correction with kilosort. 
""" - if not (self.interim.parent/"ks4_with_wavpack.sif").exists(): + ks_image_path = self.interim.parent/"ks4_with_wavpack.sif" + + if not ks_image_path.exists(): raise PixelsError("Have you craeted Singularity image for sorting?") # preprocess and motion correct raw From d4502a83ea8a9281ce66435558986d64022c1671 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 13:48:35 +0000 Subject: [PATCH 196/658] save curated sorting analyser in `processed`; use sort_spikes in pixels utils --- pixels/behaviours/base.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c7479fc..f974779 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -912,28 +912,44 @@ def sort_spikes(self, mc_method="dredge"): streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): # check if already sorted and exported - sa_dir = self.find_file(stream_files["sorting_analyser"]) - if not sa_dir.exists() or not len(os.listdir(sa_dir)) > 1: + sa_dir = self.processed / stream_files["sorting_analyser"] + if not sa_dir.exists(): print(f"> {self.name} {stream_id} not sorted/exported.\n") else: print("> Already sorted and exported, next session.\n") continue - # load preprocessed - preprocessed = self.find_file(stream_files["preprocessed"]) + # get catgt directory + catgt_dir = self.find_file( + stream_files["CatGT_ap_data"][stream_num] + ) # find spike sorting output folder - if not len(re.findall("_t[0-9]+", preprocessed.as_posix())) == 0: - output = self.processed / f"sorted_stream_cat_{stream_num}" - else: + if catgt_dir is None: output = sa_dir.parent + else: + output = self.processed / f"sorted_stream_cat_{stream_num}" - # load preprocessed rec - rec = si.load_extractor(preprocessed) + # load rec + if ks_mc: + rec = stream_files["preprocessed"] + else: + rec_dir = self.find_file(stream_files["motion_corrected"]) + rec = si.load_extractor(rec_dir) # move current working directory to interim os.chdir(self.interim) + # sort spikes and save sorting analyser to disk + xut.sort_spikes( + rec=rec, + output=output, + curated_sa_dir=sa_dir, + ks_image_path=self.interim.parent, + ks4_params=ks4_params, + ) + assert 0 + # run sorter sorting = ss.run_sorter( sorter_name="kilosort4", From bd8f95b0237c89aedc6a55250cf8bf2155f8e972 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 13:54:26 +0000 Subject: [PATCH 197/658] keep all annotations but is_filtered --- pixels/behaviours/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f974779..404bd5e 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -977,16 +977,14 @@ def sort_spikes(self, mc_method="dredge"): recording._properties = rec._properties # >>> annotations >>> - annotations = ["probe_0_planar_contour", "probes_info", - "stream_name", "stream_id"] + annotations = rec.get_annotation_keys() + annotations.remove("is_filtered") for ann in annotations: recording.set_annotation( annotation_key=ann, value=rec.get_annotation(ann), overwrite=True, ) - # original file info - recording.annotate(file_origin=str(preprocessed)) # <<< annotations <<< # curate sorter output From 1dcc1cd2abc54dddc28d5476e744c3110f686131 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 13:55:19 +0000 Subject: [PATCH 198/658] calculate waveform pca for phy --- pixels/behaviours/base.py | 1 + 1 
file changed, 1 insertion(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 404bd5e..9e7ee8a 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1023,6 +1023,7 @@ def sort_spikes(self, mc_method="dredge"): "template_similarity", "spike_amplitudes", "correlograms" + "principal_components", # for phy ] sa.compute( required_extensions, From 96d0c87e3bcaf3aadde1c87df0b51b14c83107de Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 30 Jan 2025 13:55:38 +0000 Subject: [PATCH 199/658] use direct method to get locations --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 9e7ee8a..a4c3d14 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1055,7 +1055,7 @@ def sort_spikes(self, mc_method="dredge"): # >>> get depth of units on each shank >>> # get probe geometry coordinates - coords = sa.get_probe().contact_positions + coords = sa.get_channel_locations() # get coordinates of max channel of each unit on probe, column 0 is x-axis, # column 1 is y-axis/depth, 0 at bottom-left channel. max_chan_coords = coords[sa.channel_ids_to_indices(max_chan)] From d8fcce6699823ed03c076bc0bb42ec30e5cfc8a6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 21 Feb 2025 16:58:41 +0000 Subject: [PATCH 200/658] move implementation to utils --- pixels/behaviours/base.py | 221 +------------------------------------- 1 file changed, 2 insertions(+), 219 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a4c3d14..aa243f0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -945,229 +945,12 @@ def sort_spikes(self, mc_method="dredge"): rec=rec, output=output, curated_sa_dir=sa_dir, - ks_image_path=self.interim.parent, + ks_image_path=ks_image_path, ks4_params=ks4_params, ) assert 0 - # run sorter - sorting = ss.run_sorter( - sorter_name="kilosort4", - recording=rec, - folder=output, - singularity_image=self.interim.parent/"ks4_with_wavpack.sif", - remove_existing_folder=True, - verbose=True, - **ks4_params, - ) - - # load ks preprocessed recording for sorting analyser - ks_preprocessed = se.read_binary( - file_paths=output/"sorter_output/temp_wh.dat", - sampling_frequency=rec.sampling_frequency, - dtype=np.int16, - num_channels=rec.get_num_channels(), - is_filtered=True, - ) - - # attach probe to ks4 preprocessed recording, from the raw - recording = ks_preprocessed.set_probe(rec.get_probe()) - # set properties to make sure sorting & sorting sa have all - # probe properties to form correct rec_attributes, esp `group` - recording._properties = rec._properties - - # >>> annotations >>> - annotations = rec.get_annotation_keys() - annotations.remove("is_filtered") - for ann in annotations: - recording.set_annotation( - annotation_key=ann, - value=rec.get_annotation(ann), - overwrite=True, - ) - # <<< annotations <<< - - # curate sorter output - # remove duplicate spikes - sorting = sc.remove_duplicated_spikes( - sorting, - censored_period_ms=0.3, - method="keep_first_iterative", - ) - # remove spikes exceeding recording number of samples - sorting = sc.remove_excess_spikes(sorting, recording) - - # remove redundant units created by ks - sorting = sc.remove_redundant_units( - sorting, - duplicate_threshold=0.9, # default is 0.8 - align=False, - remove_strategy="max_spikes", - ) - - # create sorting analyser - sa = si.create_sorting_analyzer( - sorting=sorting, - recording=recording, - ) - - # 
calculate all extensions BEFORE further steps - # list required extensions for redundant units removal and quality - # metrics - required_extensions = [ - "random_spikes", - "waveforms", - "templates", - "noise_levels", - "unit_locations", - "template_similarity", - "spike_amplitudes", - "correlograms" - "principal_components", # for phy - ] - sa.compute( - required_extensions, - save=True, - ) - - # make sure to have group id for each unit - if not "group" in sa.sorting.get_property_keys(): - # get shank id, i.e., group - group = sa.recording.get_channel_groups() - # get max peak channel for each unit - max_chan = si.get_template_extremum_channel(sa).values() - # get group id for each unit - unit_group = group[list(max_chan)] - # set unit group as a property for sorting - sa.sorting.set_property( - key="group", - values=unit_group, - ) - - # calculate quality metrics - qms = sqm.compute_quality_metrics(sa) - - # export pre curation report - sexp.export_report( - sorting_analyzer=sa, - output_folder=output/"report", - ) - - # >>> get depth of units on each shank >>> - # get probe geometry coordinates - coords = sa.get_channel_locations() - # get coordinates of max channel of each unit on probe, column 0 is x-axis, - # column 1 is y-axis/depth, 0 at bottom-left channel. - max_chan_coords = coords[sa.channel_ids_to_indices(max_chan)] - # set coordinates of max channel of each unit as a property of sorting - sa.sorting.set_property( - key="max_chan_coords", - values=max_chan_coords, - ) - # <<< get depth of units on each shank <<< - - # save pre-curated analyser to disk - sa.save_as( - format="zarr", - folder=output/"sa.zarr", - ) - - # remove bad units - #rule = "sliding_rp_violation <= 0.1 & amplitude_median <= -50\ - # & amplitude_cutoff < 0.05 & sd_ratio < 1.5 & presence_ratio > 0.9\ - # & snr > 1.1 & rp_contamination < 0.2 & firing_rate > 0.1" - # use the ibl methods, but amplitude_cutoff rather than noise_cutoff - rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -50\ - & presence_ratio > 0.9" - good_qms = qms.query(rule) - # TODO nov 26 2024 - # wait till noise cutoff implemented and include that. - # also see why sliding rp violation gives loads nan. 
- - # get unit ids - curated_unit_ids = list(good_qms.index) - - # select curated - curated_sa = sa.select_units(curated_unit_ids) - - # save sa to disk - curated_sa.save_as( - format="zarr", - folder=sa_dir, - ) - - # export report - sexp.export_report( - sorting_analyzer=curated_sa, - output_folder=output/"curated_report", - ) - assert 0 - if old: - print("\n> loading old kilosort 3 results to spikeinterface") - sorting = se.read_kilosort(old_ks_output_dir) # avoid re-sort old - # remove empty units - ks3_output = sorting.remove_empty_units() - print(f"> KS3 removed\ - \n{len(sorting.get_unit_ids()) - len(ks3_output.get_unit_ids())}\ - empty units.\n") - else: - try: - ks3_output = si.load_extractor(output / 'saved_si_sorting_obj') - #sorting_KS = read_kilosort(folder_path="kilosort-folder") - print(f"> {self.name} {stream_id} is already sorted, now it is loaded.\n") - - """ - # for testing: get first 5 mins of the recording - fs = concat_rec.get_sampling_frequency() - test = concat_rec.frame_slice( - start_frame=0*fs, - end_frame=300*fs, - ) - test.annotate(is_filtered=True) - # check all annotations - test.get_annotation('is_filtered') - print(test) - """ - - except: - print(f"> Now kilosorting {self.name} {stream_id}: \n{concat_rec}\n") - #ks3_output = ss.run_kilosort3(recording=concat_rec, output_folder=output) - sorting = ss.run_sorter( - sorter_name='kilosort3', - recording=concat_rec, #recording=test, # for testing - output_folder=output, - remove_existing_folder=True, - **job_kwargs, - ) - - # remove empty units - ks3_output = sorting.remove_empty_units() - print(f"> KS3 removed\ - \n{len(sorting.get_unit_ids()) - len(ks3_output.get_unit_ids())}\ - empty units.\n") - - """ - #TODO: remove duplicated spikes from spike train, only in >0.96.1 si - ks3_output = sc.remove_duplicated_spikes( - sorting=ks3_no_empt, - censored_period_ms=0.3, #ms - method='keep_first', # keep first spike, remove the second - ) - """ - # save spikeinterface sorting object for easier loading - ks3_output.save(folder=output / 'saved_si_sorting_obj') - - #TODO: toggle load_if_exists=True & overwrite=False should replace - #...load_from_folder. 
- cache = self.interim / f'cache_{stream_num}' - try: - waveforms = si.WaveformExtractor.load_from_folder( - folder=cache, - sorting=ks3_output, - ) - print(f"> {self.name} {stream_id} waveforms extracted, now it is loaded.\n") - except: - assert 0 + return None def extract_videos(self, force=False): From cb4a1126586ae7ca91a161cfaf53ae22e2cd768a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 21 Feb 2025 17:00:24 +0000 Subject: [PATCH 201/658] add import --- pixels/pixels_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index fc71611..71e2944 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import spikeinterface as si import spikeinterface.extractors as se From 04c2a10f3ef68f3665af7bd6b6cf16f945c02597 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 21 Feb 2025 17:00:40 +0000 Subject: [PATCH 202/658] allows loc_method definition --- pixels/pixels_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 71e2944..3d8923b 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -132,7 +132,7 @@ def correct_motion(rec, mc_method="dredge"): return mcd -def detect_n_localise_peaks(rec): +def detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): """ Get a sense of possible drifts in the recordings by looking at a "positional raster plot", i.e. the depth of the spike as function of @@ -157,7 +157,7 @@ def detect_n_localise_peaks(rec): dfs = [] for g, group in enumerate(groups.values()): print(f"\n> Estimate drift of shank {g}") - dfs.append(_estimate_drift(group, loc_method)) + dfs.append(_detect_n_localise_peaks(group, loc_method)) # concat shanks df = pd.concat( dfs, @@ -166,12 +166,12 @@ def detect_n_localise_peaks(rec): names=["shank", "spike_properties"] ) else: - df = self._estimate_drift(rec, loc_method) + df = self._detect_n_localise_peaks(rec, loc_method) return df -def _detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): +def _detect_n_localise_peaks(rec, loc_method): """ implementation of drift estimation. """ From 037e553c3b7a9189847c52ad73fdbde5be1f47d0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 21 Feb 2025 17:01:03 +0000 Subject: [PATCH 203/658] add doc; allows different ftypes --- pixels/pixels_utils.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 3d8923b..cba2d67 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -210,12 +210,31 @@ def _detect_n_localise_peaks(rec, loc_method): return df -def extract_band(rec, freq_min, freq_max): +def extract_band(rec, freq_min, freq_max, ftype="butter"): + """ + Band pass filter recording. + + params + === + freq_min: float, high-pass cutoff corner frequency. + + freq_max: float, low-pass cutoff corner frequency. + + ftype: str, filter type. + since its posthoc, we use 5th order acausal filter, and takes + second-order sections (SOS) representation of the filter. but more + filters to choose from, e.g., bessel with filter_order=2, presumably + preserves waveform better? see lussac. + + return + === + band: spikeinterface recording object. 
+ """ band = spre.bandpass_filter( rec, freq_min=freq_min, freq_max=freq_min, - ftype="butterworth", + ftype=ftype, ) return band From 78955bda4141e3ffddc1872311c2e60bfcad35e7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 21 Feb 2025 17:01:17 +0000 Subject: [PATCH 204/658] add notes --- pixels/pixels_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index cba2d67..73f3f59 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -263,6 +263,14 @@ def sort_spikes(rec, output, curated_sa_dir, ks_image_path, ks4_params): recording: spikeinterface recording object. """ + # NOTE: jan 30 2025 do we sort shanks separately??? + # if shanks are sorted separately, they will have separate sorter output, we + # will have to build an analyser for each group... + # maybe easier to just run all shanks together? + # the only way to concatenate four temp.dat and only create one sorting + # analyser is to read temp_wh.dat, set channels separately from raw, and + # si.aggregate_channels... + # sort spikes sorting, recording = _sort_spikes( rec, From d14116c37df16791f314e9745976718200a2e94d Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 21 Feb 2025 17:01:32 +0000 Subject: [PATCH 205/658] add func to sort shank separately --- pixels/pixels_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 73f3f59..d8be123 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -317,6 +317,18 @@ def _sort_spikes(rec, output, ks_image_path, ks4_params): recording: spikeinterface recording object. """ + # run sorter per shank + #sorting = ss.run_sorter_by_property( + # sorter_name='kilosort4', + # recording=rec, + # grouping_property="group", + # folder=output, + # singularity_image=ks_image_path, + # remove_existing_folder=True, + # verbose=True, + # **ks4_params, + #) + # run sorter sorting = ss.run_sorter( sorter_name="kilosort4", From 974630217d30f2ee4b7b41218dfb621c0ac62f13 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 21 Feb 2025 17:01:45 +0000 Subject: [PATCH 206/658] do not calculate pca cuz it crashes, do it when needed later --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d8be123..6408933 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -421,7 +421,7 @@ def _curate_sorting(sorting, recording, output): "template_similarity", "spike_amplitudes", "correlograms", - "principal_components", # for # phy + #"principal_components", # for # phy ] sa.compute(required_extensions, save=True) From aac408fa55ffa06a35c6eda47dc82b4524619063 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:17:26 +0000 Subject: [PATCH 207/658] if it already exists then return it --- pixels/behaviours/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index aa243f0..5420913 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1467,6 +1467,8 @@ def _get_processed_data(self, attr, key, category): if saved[0] is None: files = self.files[category] + else: + return saved if key in files: dirs = files[key] From 257f50c1c0b4c6ff129378330c0b71d9d0463146 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:17:46 +0000 Subject: [PATCH 208/658] remove todo; get the first action label --- pixels/behaviours/base.py | 5 ++--- 1 file changed, 2 
insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 5420913..5be31f7 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1483,6 +1483,7 @@ def _get_processed_data(self, attr, key, category): msg = f"Could not find {attr[1:]} for recording {rec_num}." msg += f"\nFile should be at: {file_path}" raise PixelsError(msg) + return saved def get_action_labels(self): @@ -1490,8 +1491,6 @@ def get_action_labels(self): Returns the action labels, either from self._action_labels if they have been loaded already, or from file. """ - # TODO jul 5 2024: only one action label for a session, make sure it - # does not error return self._get_processed_data("_action_labels", "action_labels", "behaviour") @@ -1650,7 +1649,7 @@ def _get_aligned_spike_times( align_trials delegates to this function, and should be used for getting aligned data in scripts. """ - action_labels = self.get_action_labels() + action_labels = self.get_action_labels()[0] if units is None: units = self.select_units() From 32c02089b64a412122f5c7e5ed388bb010f85df2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:19:10 +0000 Subject: [PATCH 209/658] separate binning into a separate func --- pixels/behaviours/base.py | 73 +++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 5be31f7..b43a17e 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1742,7 +1742,6 @@ def _get_aligned_spike_times( def _get_aligned_trials( self, label, event, units=None, sigma=None, end_event=None, - time_bin=None, pos_bin=None, ): """ Returns spike rate for each unit within a trial. @@ -1900,42 +1899,42 @@ def _get_aligned_trials( rec_trials[trial_ids[i]] = rates trial_positions[trial_ids[i]] = trial_pos - # get bin firing rates - bin_frs[i] = self.bin_vr_trial( - data=rates, - positions=trial_pos, - time_bin=time_bin, - pos_bin=pos_bin, - bin_method="mean", - ) - # get bin spike count - bin_counts[i] = self.bin_vr_trial( - data=spiked, - positions=trial_pos, - time_bin=time_bin, - pos_bin=pos_bin, - bin_method="sum", - ) - - # stack df values into np array - # reshape into trials x units x bins - bin_count_arr = ioutils.reindex_by_longest(bin_counts).T - bin_fr_arr = ioutils.reindex_by_longest(bin_frs).T - - # save bin_fr and bin_count, for alfredo & andrew - # use label as array key name - fr_to_save = { - "fr": bin_fr_arr[:, :-2, :], - "pos": bin_fr_arr[:, -2:, :], - } - np.savez_compressed(output_fr_path, **fr_to_save) - print(f"> Output saved at {output_fr_path}.") - count_to_save = { - "count": bin_count_arr[:, :-2, :], - "pos": bin_count_arr[:, -2:, :], - } - np.savez_compressed(output_count_path, **count_to_save) - print(f"> Output saved at {output_count_path}.") + ## get bin firing rates + #bin_frs[i] = self.bin_vr_trial( + # data=rates, + # positions=trial_pos, + # time_bin=time_bin, + # pos_bin=pos_bin, + # bin_method="mean", + #) + ## get bin spike count + #bin_counts[i] = self.bin_vr_trial( + # data=spiked, + # positions=trial_pos, + # time_bin=time_bin, + # pos_bin=pos_bin, + # bin_method="sum", + #) + + ## stack df values into np array + ## reshape into trials x units x bins + #bin_count_arr = ioutils.reindex_by_longest(bin_counts).T + #bin_fr_arr = ioutils.reindex_by_longest(bin_frs).T + + ## save bin_fr and bin_count, for alfredo & andrew + ## use label as array key name + #fr_to_save = { + # "fr": bin_fr_arr[:, :-2, :], + # "pos": 
bin_fr_arr[:, -2:, :], + #} + #np.savez_compressed(output_fr_path, **fr_to_save) + #print(f"> Output saved at {output_fr_path}.") + #count_to_save = { + # "count": bin_count_arr[:, :-2, :], + # "pos": bin_count_arr[:, -2:, :], + #} + #np.savez_compressed(output_count_path, **count_to_save) + #print(f"> Output saved at {output_count_path}.") if not rec_trials: return None From fcb6502cf2b646037092ad690ee384e78b34cfc5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:19:33 +0000 Subject: [PATCH 210/658] add unit_kwargs to allow separate shanks --- pixels/behaviours/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b43a17e..7a8d9c4 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1958,7 +1958,8 @@ def _get_aligned_trials( def select_units( self, group='good', min_depth=0, max_depth=None, min_spike_width=None, - max_spike_width=None, uncurated=False, name=None, use_si=False, + unit_kwargs=None, max_spike_width=None, uncurated=False, name=None, + use_si=False, ): """ Select units based on specified criteria. The output of this can be passed to From c0bc894c8fb99bc9c2a634dfd8bee418928360b2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:21:47 +0000 Subject: [PATCH 211/658] gets units from separate shanks --- pixels/behaviours/base.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7a8d9c4..d812a38 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2013,13 +2013,34 @@ def select_units( if name is not None: selected_units.name = name + # get shank id for units + shank_ids = sa.sorting.get_property("group") + # get coordinates of channel with max. amplitude max_chan_coords = sa.sorting.get_property("max_chan_coords") # get depths depths = max_chan_coords[:, 1] - # select units within depths range - in_range = unit_ids[(depths >= min_depth) & (depths < max_depth)] - selected_units.extend(in_range) + + if unit_kwargs: + for shank_id, kwargs in unit_kwargs.items(): + # get shank depths + min_depth = kwargs["min_depth"] + max_depth = kwargs["max_depth"] + # find units + in_range = unit_ids[ + (depths >= min_depth) & (depths < max_depth) &\ + (shank_ids == shank_id) + ] + # add to list + selected_units.extend(in_range) + else: + # if there is only one shank + # find units + in_range = unit_ids[ + (depths >= min_depth) & (depths < max_depth) + ] + # add to list + selected_units.extend(in_range) return selected_units From a36d576a9306bbb8750b29f91d30903fdc0a5ce8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:25:54 +0000 Subject: [PATCH 212/658] no need to take time/position bin for aligning trials --- pixels/behaviours/base.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d812a38..74f5dd2 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2138,7 +2138,7 @@ def get_lfp_data_raw(self): def align_trials( self, label, event, units=None, data='spike_times', raw=False, duration=1, sigma=None, dlc_project=None, video_match=None, - end_event=None, time_bin=None, pos_bin=False, + end_event=None, ): """ Get trials aligned to an event. 
This finds all instances of label in the action @@ -2184,12 +2184,6 @@ def align_trials( end_event : int | None For VR behaviour, when aligning to the whole trial, this param is the end event to align to. - - time_bin: str | None - For VR behaviour, size of temporal bin for spike rate data. - - pos_bin: int | None - For VR behaviour, size of positional bin for position data. """ data = data.lower() @@ -2931,8 +2925,7 @@ def bin_vr_trial(self, data, positions, time_bin, pos_bin, bin_method="mean"): @_cacheable def get_positional_rate( - self, label, event, end_event=None, sigma=None, time_bin=None, - pos_bin=None, units=None, + self, label, event, end_event=None, sigma=None, units=None, ): """ Get positional firing rate of selected units in vr, and spatial @@ -2957,8 +2950,6 @@ def get_positional_rate( data="trial_rate", sigma=sigma, end_event=end_event, - time_bin=time_bin, - pos_bin=pos_bin, ) fr = trials["fr"] positions = trials["positions"] From decc33068230ecf6fd71239d120eeab7e3c6a407 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:26:19 +0000 Subject: [PATCH 213/658] print a bit more info --- pixels/behaviours/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 74f5dd2..f29078f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2210,11 +2210,10 @@ def align_trials( ) if data == "trial_rate": - print(f"Aligning {data} to trials.") + print(f"Aligning {data} of {units} units to trials.") # we let a dedicated function handle aligning spike times return self._get_aligned_trials( label, event, units=units, sigma=sigma, end_event=end_event, - time_bin=time_bin, pos_bin=pos_bin, ) if data == "motion_tracking" and not dlc_project: From 54d9a35ea48b3431823eadb16aab640a47791e3c Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 26 Feb 2025 17:55:49 +0000 Subject: [PATCH 214/658] put binning aligned trials in a separate func --- pixels/behaviours/base.py | 209 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f29078f..4babab2 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3031,5 +3031,214 @@ def get_positional_rate( return {"pos_fr": pos_fr, "occupancy": occupancy} + + def bin_aligned_trials( + self, label, event, units=None, sigma=None, end_event=None, + time_bin=None, pos_bin=None, + ): + """ + Returns spike rate for each unit within a trial. + align_trials delegates to this function, and should be used for getting aligned + data in scripts. + + This function also saves binned data in the format that Andrew wants: + trials * units * temporal bins (ms) + + time_bin: str | None + For VR behaviour, size of temporal bin for spike rate data. + + pos_bin: int | None + For VR behaviour, size of positional bin for position data. 
+ + """ + action_labels = self.get_action_labels()[0] + + # define output path for binned spike rate + output_fr_path = self.interim/\ + f'cache/{self.name}_{label}_{units}_{time_bin}_spike_rate.npz' + output_count_path = self.interim/\ + f'cache/{self.name}_{label}_{units}_{time_bin}_spike_count.npz' + + if output_count_path.exists() and output_fr_path.exists(): + print(f"> {self.name} {label} {units} {time_bin} .npz already " + "saved.") + return None + + print(f"> Binning data from {self.name} {label} {units} units to " + f"{time_bin}.") + if units is None: + units = self.select_units() + if not pos_bin is None: + behaviour_files = self.files["behaviour"] + # assume only one vr session for now + vr_dir = self.find_file(behaviour_files["vr_synched"][0]) + vr_data = ioutils.read_hdf5(vr_dir) + # get positions + positions = vr_data.position_in_tunnel + + #TODO: with multiple streams, spike times will be a list with multiple dfs, + #make sure old code does not break! + spikes = self.get_spike_times(use_si=True)[units] + # drop rows if all nans + spikes = spikes.dropna(how="all") + + # since each session has one behaviour session, now only one action + # label file + actions = action_labels["outcome"] + events = action_labels["events"] + # get timestamps index of behaviour in self.SAMPLE_RATE hz, to convert + # it to ms, do timestamps*1000/self.SAMPLE_RATE + timestamps = action_labels["timestamps"] + + # select frames of wanted trial type + trials = np.where(np.bitwise_and(actions, label))[0] + # map starts by event + starts = np.where(np.bitwise_and(events, event))[0] + # map starts by end event + ends = np.where(np.bitwise_and(events, end_event))[0] + + # only take starts from selected trials + selected_starts = trials[np.where(np.isin(trials, starts))[0]] + start_t = timestamps[selected_starts] + # only take ends from selected trials + selected_ends = trials[np.where(np.isin(trials, ends))[0]] + end_t = timestamps[selected_ends] + + # use original trial id as trial index + trial_ids = vr_data.iloc[selected_starts].trial_count.unique() + + # pad ends with 1 second extra to remove edge effects from convolution + scan_pad = self.SAMPLE_RATE + scan_starts = start_t - scan_pad + scan_ends = end_t + scan_pad + scan_durations = scan_ends - scan_starts + + cursor = 0 # In sample points + rec_trials = {} + trial_positions = {} + bin_frs = {} + bin_counts = {} + + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + # TODO jun 12 2024 skip other streams for now + if stream_num > 0: + continue + + # Account for multiple raw data files + meta = self.ap_meta[stream_num] + samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 + assert samples.is_integer() + in_SAMPLE_RATE_scale = (samples * self.SAMPLE_RATE)\ + / int(self.ap_meta[0]['imSampRate']) + cursor_duration = (cursor * self.SAMPLE_RATE)\ + / int(self.ap_meta[0]['imSampRate']) + rec_spikes = spikes[ + (cursor_duration <= spikes)\ + & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) + ] - cursor_duration + cursor += samples + + # Account for lag, in case the ephys recording was started before the + # behaviour + if not self._lag[stream_num] == None: + lag_start, _ = self._lag[stream_num] + else: + lag_start = timestamps[0] + + if lag_start < 0: + rec_spikes = rec_spikes + lag_start + + for i, start in enumerate(selected_starts): + # select spike times of current trial + trial_bool = (rec_spikes >= scan_starts[i])\ + & (rec_spikes <= scan_ends[i]) + trial = 
rec_spikes[trial_bool] + + # get position bin ids for current trial + trial_pos_bool = (positions.index >= start_t[i])\ + & (positions.index < end_t[i]) + trial_pos = positions[trial_pos_bool] + + # initiate binary spike times array for current trial + # NOTE: dtype must be float otherwise would get all 0 when + # passing gaussian kernel + times = np.zeros((scan_durations[i], len(units))).astype(float) + # use pixels time as spike index + idx = np.arange(scan_starts[i], scan_ends[i]) + # make it df, column name being unit id + spiked = pd.DataFrame(times, index=idx, columns=units) + + for unit in trial: + # get spike time for unit + u_times = trial[unit].values + # drop nan + u_times = u_times[~np.isnan(u_times)] + + # round spike times to use it as index + u_spike_idx = np.round(u_times).astype(int) + # make sure it does not exceed scan duration + if (u_spike_idx >= scan_ends[i]).any(): + beyonds = np.where(u_spike_idx >= scan_ends[i])[0] + u_spike_idx[beyonds] = idx[-1] + # make sure no double counted + u_spike_idx = np.unique(u_spike_idx) + + # set spiked to 1 + spiked.loc[u_spike_idx, unit] = 1 + + # convolve spike trains into spike rates + rates = signal.convolve_spike_trains( + times=spiked, + sigma=sigma, + sample_rate=self.SAMPLE_RATE, + ) + # remove 1s padding from the start and end + rates = rates.iloc[scan_pad: -scan_pad] + spiked = spiked.iloc[scan_pad: -scan_pad] + # reset index to zero at the beginning of the trial + rates.reset_index(inplace=True, drop=True) + spiked.reset_index(inplace=True, drop=True) + trial_pos.reset_index(inplace=True, drop=True) + + # get bin firing rates + bin_frs[i] = self.bin_vr_trial( + data=rates, + positions=trial_pos, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="mean", + ) + # get bin spike count + bin_counts[i] = self.bin_vr_trial( + data=spiked, + positions=trial_pos, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="sum", + ) + + # stack df values into np array + # reshape into trials x units x bins + bin_count_arr = ioutils.reindex_by_longest(bin_counts).T + bin_fr_arr = ioutils.reindex_by_longest(bin_frs).T + + # save bin_fr and bin_count, for alfredo & andrew + # use label as array key name + fr_to_save = { + "fr": bin_fr_arr[:, :-2, :], + "pos": bin_fr_arr[:, -2:, :], + } + np.savez_compressed(output_fr_path, **fr_to_save) + print(f"> Output saved at {output_fr_path}.") + + count_to_save = { + "count": bin_count_arr[:, :-2, :], + "pos": bin_count_arr[:, -2:, :], + } + np.savez_compressed(output_count_path, **count_to_save) + print(f"> Output saved at {output_count_path}.") + + return None From d2ed1cb7650e3ec29d01002caaa01edcb5d0353a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 28 Feb 2025 12:28:44 +0000 Subject: [PATCH 215/658] rename so that ranger does not crash --- pixels/{signal.py => signal_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pixels/{signal.py => signal_utils.py} (100%) diff --git a/pixels/signal.py b/pixels/signal_utils.py similarity index 100% rename from pixels/signal.py rename to pixels/signal_utils.py From 849b2136b0f1a4ea8c6cf14a85b92581f31d06f2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 28 Feb 2025 12:30:17 +0000 Subject: [PATCH 216/658] rename signal.py so that ranger does not crash --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4babab2..dacd39b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -36,7 +36,7 @@ from wavpack_numcodecs 
import WavPack from pixels import ioutils -from pixels import signal +import pixels.signal_utils as signal import pixels.pixels_utils as xut from pixels.error import PixelsError from pixels.constants import * From b7ad653539610c19ae4faec14033f2a06955f77c Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:32:43 +0000 Subject: [PATCH 217/658] allows input array to be np --- pixels/signal_utils.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/pixels/signal_utils.py b/pixels/signal_utils.py index 991b15e..a585f87 100644 --- a/pixels/signal_utils.py +++ b/pixels/signal_utils.py @@ -327,18 +327,29 @@ def convolve_spike_trains(times, sigma=100, size=10, sample_rate=1000): # normalise kernel to ensure that the total area under the Gaussian is 1 n_kernel = kernel / np.sum(kernel) - # convolve with gaussian - convolved = convolve1d( - input=times.values, - weights=n_kernel, - output=float, - mode='nearest', - axis=0, - ) * sample_rate # rescale it to second + if isinstance(times, pd.DataFrame): + # convolve with gaussian + convolved = convolve1d( + input=times.values, + weights=n_kernel, + output=float, + mode='nearest', + axis=0, + ) * sample_rate # rescale it to second + + output = pd.DataFrame(convolved, columns=times.columns) + + elif isinstance(times, np.ndarray): + # convolve with gaussian + output = convolve1d( + input=times, + weights=n_kernel, + output=float, + mode='nearest', + axis=0, + ) * sample_rate # rescale it to second - df = pd.DataFrame(convolved, columns=times.columns) - - return df + return output def convolve(times, duration, sigma=None): From 303d86f1b251f7eca22e4a93b6451e02ebc70c8a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:33:01 +0000 Subject: [PATCH 218/658] change signal module name --- pixels/behaviours/virtual_reality.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 8a005ff..7bf3fae 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -18,7 +18,8 @@ from vision_in_darkness.base import Outcomes, Worlds, Conditions from pixels import Experiment, PixelsError -from pixels import signal, ioutils +import pixels.signal_utils as signal +from pixels import ioutils from pixels.behaviours import Behaviour from common_utils import file_utils From 25b19c2bb4caee295708108f89b15a69a42896b4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:36:15 +0000 Subject: [PATCH 219/658] takes data type arg to save trial_times --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index dacd39b..d2df885 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1741,7 +1741,7 @@ def _get_aligned_spike_times( def _get_aligned_trials( - self, label, event, units=None, sigma=None, end_event=None, + self, label, event, data, units=None, sigma=None, end_event=None, ): """ Returns spike rate for each unit within a trial. 
From d2f26ab2cb98da5a65fd0877f55af1e42936cf43 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:37:09 +0000 Subject: [PATCH 220/658] setup vr behaviour vars --- pixels/behaviours/base.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d2df885..caafc82 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1754,22 +1754,16 @@ def _get_aligned_trials( """ action_labels = self.get_action_labels()[0] - # define output path for binned spike rate - output_fr_path = self.interim/\ - f'cache/{self.name}_{label}_{units}_{time_bin}_spike_rate.npz' - output_count_path = self.interim/\ - f'cache/{self.name}_{label}_{units}_{time_bin}_spike_count.npz' - if units is None: units = self.select_units() - if not pos_bin is None: - behaviour_files = self.files["behaviour"] - # assume only one vr session for now - vr_dir = self.find_file(behaviour_files["vr_synched"][0]) - vr_data = ioutils.read_hdf5(vr_dir) - # get positions - positions = vr_data.position_in_tunnel + #if not pos_bin is None: + behaviour_files = self.files["behaviour"] + # assume only one vr session for now + vr_dir = self.find_file(behaviour_files["vr_synched"][0]) + vr_data = ioutils.read_hdf5(vr_dir) + # get positions + positions = vr_data.position_in_tunnel #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! From a31cca7ddc84f6d8382c29171692aec023faf730 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:37:43 +0000 Subject: [PATCH 221/658] also save spiked boolean --- pixels/behaviours/base.py | 117 +++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 57 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index caafc82..95105d8 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1803,16 +1803,17 @@ def _get_aligned_trials( scan_durations = scan_ends - scan_starts cursor = 0 # In sample points - rec_trials = {} + rec_trials_fr = {} + rec_trials_spiked = {} trial_positions = {} bin_frs = {} bin_counts = {} streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - # TODO jun 12 2024 skip other streams for now - if stream_num > 0: - continue + # allows multiple streams of recording, i.e., multiple probes + rec_trials_fr[stream_id] = {} + rec_trials_spiked[stream_id] = {} # Account for multiple raw data files meta = self.ap_meta[stream_num] @@ -1887,65 +1888,67 @@ def _get_aligned_trials( spiked = spiked.iloc[scan_pad: -scan_pad] # reset index to zero at the beginning of the trial rates.reset_index(inplace=True, drop=True) + rec_trials_fr[stream_id][trial_ids[i]] = rates spiked.reset_index(inplace=True, drop=True) + rec_trials_spiked[stream_id][trial_ids[i]] = spiked trial_pos.reset_index(inplace=True, drop=True) - - rec_trials[trial_ids[i]] = rates trial_positions[trial_ids[i]] = trial_pos - ## get bin firing rates - #bin_frs[i] = self.bin_vr_trial( - # data=rates, - # positions=trial_pos, - # time_bin=time_bin, - # pos_bin=pos_bin, - # bin_method="mean", - #) - ## get bin spike count - #bin_counts[i] = self.bin_vr_trial( - # data=spiked, - # positions=trial_pos, - # time_bin=time_bin, - # pos_bin=pos_bin, - # bin_method="sum", - #) - - ## stack df values into np array - ## reshape into trials x units x bins - #bin_count_arr = ioutils.reindex_by_longest(bin_counts).T - #bin_fr_arr = 
ioutils.reindex_by_longest(bin_frs).T - - ## save bin_fr and bin_count, for alfredo & andrew - ## use label as array key name - #fr_to_save = { - # "fr": bin_fr_arr[:, :-2, :], - # "pos": bin_fr_arr[:, -2:, :], - #} - #np.savez_compressed(output_fr_path, **fr_to_save) - #print(f"> Output saved at {output_fr_path}.") - #count_to_save = { - # "count": bin_count_arr[:, :-2, :], - # "pos": bin_count_arr[:, -2:, :], - #} - #np.savez_compressed(output_count_path, **count_to_save) - #print(f"> Output saved at {output_count_path}.") + #if not rec_trials_fr[stream_id]: + # return None + + # TODO feb 28 2025: + # in this func, scanning period is longer than actual trials to avoid + # edging effect, so spiked is the whole scanning period then convolved + # into spike rate. if however, we get the spike_bool here, it does not + # have the scanning period buffer, the fr might have edging effect. but + # if we take the scanning period, the number of spikes will be + # different from the real data... + # SOLUTION: concat as below, shuffle per column, then convolve per + # column + # TODO mar 7 2025: + # CONTINUE HERE! + if data == "trial_times": + spiked = pd.concat( + rec_trials_spiked[stream_id], + axis=0, + ) + assert 0 - if not rec_trials: - return None + continue + s_chance_path = self.interim/stream_files["spiked_shuffled_memmap"] + fr_chance_path = self.interim/stream_files["fr_shuffled_memmap"] + chance_df_path = self.processed/stream_files["shuffled"] - # concat trial df - positions = ioutils.reindex_by_longest( - dfs=trial_positions, - return_format="dataframe", - names="trial", - ) - fr = ioutils.reindex_by_longest( - dfs=rec_trials, - return_format="dataframe", - names=["trial", "unit"], - ) - fr = fr.reorder_levels(["unit", "trial"], axis=1) - fr = fr.sort_index(level=0, axis=1) + chance_data = xut.get_spike_chance( + spiked=pd.concat(rec_trials_spiked[stream_id], axis=0), + sigma=sigma, + sample_rate=self.SAMPLE_RATE, + spiked_chance_path=s_chance_path, + fr_chance_path=fr_chance_path, + chance_df_path=chance_df_path, + ) + assert 0 + # concat trial df + positions = ioutils.reindex_by_longest( + dfs=trial_positions, + return_format="dataframe", + names="trial", + ) + + fr = ioutils.reindex_by_longest( + dfs=rec_trials_fr[stream_id], + return_format="dataframe", + names=["trial", "unit"], + ) + fr = fr.reorder_levels(["unit", "trial"], axis=1) + fr = fr.sort_index(level=0, axis=1) + + spiked = ioutils.reindex_by_longest( + dfs=rec_trials_spiked[stream_id], + return_format="dataframe", + names=["trial", "unit"], + ) return {"fr": fr, "positions": positions} From 736126475aeefd0f87f0391bf63f077556c19348 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:38:34 +0000 Subject: [PATCH 222/658] allows to take all units just based on name --- pixels/behaviours/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 95105d8..d6933a0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2010,6 +2010,10 @@ def select_units( if name is not None: selected_units.name = name + if name == "all": + selected_units.extend(unit_ids) + return selected_units + # get shank id for units shank_ids = sa.sorting.get_property("group") From eecedd9f6b7338f7080b5e7826858b0bd30711f0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:39:53 +0000 Subject: [PATCH 223/658] allows trial times --- pixels/behaviours/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d6933a0..e28ae90 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2198,6 +2198,8 @@ def align_trials( 'motion_tracking', # Motion tracking coordinates from DLC 'trial_rate', # Taking spike times from the whole duration of each # trial, convolve into spike rate + 'trial_times', # Taking spike times from the whole duration of each + # trial, get spike boolean ] if data not in data_options: raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") @@ -2210,11 +2212,11 @@ def align_trials( units=units, ) - if data == "trial_rate": + if "trial" in data: print(f"Aligning {data} of {units} units to trials.") - # we let a dedicated function handle aligning spike times return self._get_aligned_trials( - label, event, units=units, sigma=sigma, end_event=end_event, + label, event, data=data, units=units, sigma=sigma, + end_event=end_event, ) if data == "motion_tracking" and not dlc_project: From c10cf4df34319f9d1e4ac54a7f0dce4604ca1ee3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:40:32 +0000 Subject: [PATCH 224/658] add todo --- pixels/behaviours/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e28ae90..be3e3bf 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1859,6 +1859,10 @@ def _get_aligned_trials( # make it df, column name being unit id spiked = pd.DataFrame(times, index=idx, columns=units) + # TODO mar 5 2025: + # how to separate aligned trial times and chance, so that i can + # use cacheable to get all conditions?????? + for unit in trial: # get spike time for unit u_times = trial[unit].values From b77d12666cac6038665ecf6bce20007debe2ba5a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:40:43 +0000 Subject: [PATCH 225/658] add func for chance data --- pixels/behaviours/base.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index be3e3bf..2135a31 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3249,3 +3249,34 @@ def bin_aligned_trials( print(f"> Output saved at {output_count_path}.") return None + + + @_cacheable + def get_spike_chance( + self, label, event, end_event=None, sigma=None, units=None, + ): + # get aligned firing rates and positions + trials = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + data="trial_times", + sigma=sigma, + end_event=end_event, + ) + + assert 0 + # TODO mar 5 2025: + # how to get spike times then get chance? and cache them? 
+ + # get chance data + chance_data = xut.get_spike_chance( + spiked=pd.concat(rec_trials_spiked[stream_id], axis=0), + sigma=sigma, + sample_rate=self.SAMPLE_RATE, + spiked_chance_path=s_chance_path, + fr_chance_path=fr_chance_path, + chance_df_path=chance_df_path, + ) + + return chance_data From a3f3ececa56228bf149d9a230fb900510292473a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:40:58 +0000 Subject: [PATCH 226/658] add imports --- pixels/pixels_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 6408933..e583dc0 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1,3 +1,5 @@ +import multiprocessing as mp + import numpy as np import pandas as pd @@ -10,6 +12,11 @@ import spikeinterface.postprocessing as spost import spikeinterface.qualitymetrics as sqm +import pixels.signal_utils as signal + +from common_utils.math_utils import random_sampling +from common_utils.file_utils import init_memmap, read_hdf5 + # set si job_kwargs job_kwargs = dict( n_jobs=0.8, # 80% core From e1f91edac1f03b945c770b5e09680a9cc79cedba Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:41:05 +0000 Subject: [PATCH 227/658] initiate random number generator --- pixels/pixels_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index e583dc0..2ebaeea 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -25,6 +25,9 @@ ) si.set_global_job_kwargs(**job_kwargs) +# initiate random number generator +rng = np.random.default_rng() + def load_raw(paths, stream_id): """ Load raw recording file from spikeglx. From 3be66e64ec26c7a8803106967458b565def29d14 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:41:27 +0000 Subject: [PATCH 228/658] permute spiked boolean repeatedly with parallel processing to create chance data --- pixels/pixels_utils.py | 263 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 263 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 2ebaeea..a99ce47 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -532,3 +532,266 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): ) return None + + +def _permute_spikes_n_convolve_fr(array, sigma, sample_rate): + """ + Randomly permute spike boolean across time. + + params + === + array: 2D np array, time points x units. + + sigma: int/float, time in millisecond of sigma of gaussian kernel for firing + rate convolution. + + sample_rate: float/int, sampling rate of signal. + + return + === + random_spiked: shuffled spike boolean for each unit. + + random_fr: convolved firing rate from shuffled spike boolean for each unit. + """ + # permutate columns + random_spiked = rng.permuted(array, axis=0) + # convolve into firing rate + random_fr = signal.convolve_spike_trains( + times=random_spiked, + sigma=sigma, + sample_rate=sample_rate, + ) + + return random_spiked, random_fr + + +def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, + spiked_chance_path, fr_chance_path): + """ + Worker that computes one set of spiked and fr values. + + params + === + i: index of current repeat. + + sigma: int/float, time in millisecond of sigma of gaussian kernel for firing + rate convolution. + + sample_rate: float/int, sampling rate of signal. + + spiked_shape: tuple, shape of spike boolean to initiate memmap. + + chance_data_shape: tuple, shape of chance data. 
+ + spiked_chance_path: + + fr_chance_path: + + return + === + """ + print(f"Processing repeat {i}...") + # open readonly memmap + spiked = init_memmap( + path=spiked_chance_path.parent/"temp_spiked.bin", + shape=spiked_shape, + dtype=np.int16, + overwrite=False, + readonly=True, + ) + + # init appendable memmap + chance_spiked = init_memmap( + path=spiked_chance_path, + shape=chance_data_shape, + dtype=np.int16, + overwrite=False, + readonly=False, + ) + # init chance firing rate memmap + chance_fr = init_memmap( + path=fr_chance_path, + shape=chance_data_shape, + dtype=np.float32, + overwrite=False, + readonly=False, + ) + + # get permuted data + c_spiked, c_fr = _permute_spikes_n_convolve_fr(spiked, sigma, sample_rate) + + chance_spiked[..., i] = c_spiked + chance_fr[..., i] = c_fr + # write to disk + chance_spiked.flush() + chance_fr.flush() + + print(f"Repeat {i} finished.") + + return None + + +def _get_spike_chance(spiked, sigma, sample_rate, spiked_chance_path, + fr_chance_path, chance_df_path, repeats=100): + """ + Implementation of getting chance level spike data. + """ + import concurrent.futures + + # get export data shape + spiked_shape = spiked.shape + d_shape = spiked.shape + (repeats,) + + if not chance_fr.exists(): + spiked_memmap = init_memmap( + path=spiked_chance_path.parent/"temp_spiked.bin", + shape=spiked.shape, + dtype=np.int16, + overwrite=True, + readonly=False, + ) + spiked_memmap[:] = spiked.values + spiked_memmap.flush() + del spiked_memmap + + # init chance spiked memmap + chance_spiked = init_memmap( + path=spiked_chance_path, + shape=d_shape, + dtype=np.int16, + overwrite=True, + readonly=False, + ) + # init chance firing rate memmap + chance_fr = init_memmap( + path=fr_chance_path, + shape=d_shape, + dtype=np.float32, + overwrite=True, + readonly=False, + ) + # write to disk + chance_spiked.flush() + chance_fr.flush() + del chance_spiked, chance_fr + + # Set up the process pool to run the worker in parallel. + with concurrent.futures.ProcessPoolExecutor() as executor: + # Submit jobs for each repeat. + futures = [] + for i in range(repeats): + future = executor.submit( + _chance_worker, + i=i, + sigma=sigma, + sample_rate=sample_rate, + spiked_shape=spiked_shape, + chance_data_shape=d_shape, + spiked_chance_path=spiked_chance_path, + fr_chance_path=fr_chance_path, + ) + futures.append(future) + + # As each future completes, assign the results into the memmap. 
+ for future in concurrent.futures.as_completed(futures): + future.result() + else: + print("\n> Memmaps already created, only need to convert into " + "dataframes.") + + # convert it to dataframe and save it + chance_data = convert_to_df( + spiked, + spiked_chance_path, + fr_chance_path, + chance_df_path, + d_shape, + ) + + return chance_data + + +def _convert_to_df(spiked, memmap_path, df_path, d_shape): + # init readonly chance memmap + chance_memmap = init_memmap( + path=memmap_path, + shape=d_shape, + dtype=np.int16, + overwrite=False, + readonly=True, + ) + + # copy to cpu + c_spiked = chance_memmap.copy() + # reshape to 2D + c_spiked_reshaped = c_spiked.reshape(d_shape[0], d_shape[1] * d_shape[2]) + # create hierarchical index + col_idx = pd.MultiIndex.from_product( + [spiked.columns, np.arange(repeats)], + names=["unit", "repeat"], + ) + + # create df + df = pd.DataFrame(c_spiked_reshaped, columns=col_idx) + # use the original index + df.index = spiked.index + # name index + df.index.names = ["trial", "time"] + + return df + + +def convert_to_df( + spiked, spiked_chance_path, fr_chance_path, chance_df_path, d_shape, +): + + # init readonly chance memmap + chance_spiked = init_memmap( + path=spiked_chance_path, + shape=d_shape, + dtype=np.int16, + overwrite=False, + readonly=True, + ) + chance_fr = init_memmap( + path=fr_chance_path, + shape=d_shape, + dtype=np.float32, + overwrite=False, + readonly=True, + ) + + # get trial ids + trial_ids = spiked.index.get_level_values(0).unique() + + assert 0 + # copy to cpu + c_spiked = chance_spiked.copy() + # reshape to 2D + c_spiked_reshaped = c_spiked.reshape(d_shape[0], d_shape[1] * d_shape[2]) + # create hierarchical index + col_idx = pd.MultiIndex.from_product( + [spiked.columns, np.arange(repeats)], + names=["unit", "repeat"], + ) + # create df + chance_spiked_df = pd.DataFrame(c_spiked_reshaped, columns=col_idx) + # use the original index + chance_spiked_df.index = spiked.index + # name index + chance_spiked_df.index.names = ["trial", "time"] + assert 0 + + return {"spiked": chance_spiked_df, "fr": chance_fr_df} + + +def get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, + fr_memmap_path, chance_df_path, repeats=100): + if chance_df_path.exists(): + # read and return + return read_hdf5(chance_df_path) + else: + chance_data = _get_spike_chance( + spiked, sigma, sample_rate, spiked_memmap_path, fr_memmap_path, + chance_df_path, repeats) + return chance_data From 4c828d6d4f94cde4fc331f6abaf61b65a43840e3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 7 Mar 2025 17:42:03 +0000 Subject: [PATCH 229/658] add chance data but todo --- pixels/ioutils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 5574db1..903642b 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -119,6 +119,24 @@ def get_data_files(data_dir, session_name): f"{session_name}_{stream_id}_channel_clustering_results.h5" ) + # TODO mar 5 2025: + # maybe do NOT put shuffled data in here, cuz there will be different + # trial conditions, better to cache them??? 
+ + # shuffled response for each unit, in light & dark conditions, to get + # the chance + # memmaps for temporary storage + pixels[stream_id]["spiked_shuffled_memmap"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_spiked_shuffled.bin" + ) + pixels[stream_id]["fr_shuffled_memmap"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_fr_shuffled.bin" + ) + # .h5 files + pixels[stream_id]["shuffled"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_shuffled.h5" + ) + # old catgt data pixels[stream_id]["CatGT_ap_data"].append( str(base_name).replace("t0", "tcat") From d132c7942f653dad2223921651b10c0d63e22119 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 31 Mar 2025 18:59:02 +0100 Subject: [PATCH 230/658] use more informative name of the context; allows defining dtype --- pixels/pixels_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index a99ce47..1d9ff4a 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -700,23 +700,24 @@ def _get_spike_chance(spiked, sigma, sample_rate, spiked_chance_path, "dataframes.") # convert it to dataframe and save it - chance_data = convert_to_df( - spiked, - spiked_chance_path, - fr_chance_path, - chance_df_path, - d_shape, + # TODO mar 31 2025: how to save it??? + chance_data = compile_chance( + original_idx=spiked.index, + spiked_chance_path=spiked_chance_path, + fr_chance_path=fr_chance_path, + chance_df_path=chance_df_path, + d_shape=d_shape, ) return chance_data -def _convert_to_df(spiked, memmap_path, df_path, d_shape): +def _convert_to_df(original_idx, memmap_path, df_path, d_shape, d_type, name): # init readonly chance memmap chance_memmap = init_memmap( path=memmap_path, shape=d_shape, - dtype=np.int16, + dtype=d_type, overwrite=False, readonly=True, ) From 1bcfeff12c3baee0bff1ab8143af1e7854f7c5c7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 31 Mar 2025 18:59:49 +0100 Subject: [PATCH 231/658] only needs original index; use _convert_to_df and save df to disk in that func --- pixels/pixels_utils.py | 70 +++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1d9ff4a..4ecbd87 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -735,55 +735,47 @@ def _convert_to_df(original_idx, memmap_path, df_path, d_shape, d_type, name): # create df df = pd.DataFrame(c_spiked_reshaped, columns=col_idx) # use the original index - df.index = spiked.index + df.index = original_idx # name index df.index.names = ["trial", "time"] - return df + # write h5 to disk + ioutils.write_hdf5( + path=df_path, + df=df, + key=name, + mode="a", + ) + return None -def convert_to_df( - spiked, spiked_chance_path, fr_chance_path, chance_df_path, d_shape, -): - # init readonly chance memmap - chance_spiked = init_memmap( - path=spiked_chance_path, - shape=d_shape, - dtype=np.int16, - overwrite=False, - readonly=True, +def compile_chance( + original_idx, spiked_chance_path, fr_chance_path, chance_df_path, d_shape, +): + # TODO mar 31 2025: test _convert_to_df + # get chance spiked df + chance_spiked_df = _convert_to_df( + original_idx=original_idx, + memmap_path=spiked_chance_path, + df_path=chance_df_path, + d_shape=d_shape, + d_type=np.int16, + name="spiked", ) - chance_fr = init_memmap( - path=fr_chance_path, - shape=d_shape, - dtype=np.float32, - overwrite=False, - readonly=True, + # get chance fr df + chance_fr_df = 
_convert_to_df( + original_idx=original_idx, + memmap_path=fr_chance_path, + df_path=chance_df_path, + d_shape=d_shape, + d_type=np.float32, + name="fr", ) - # get trial ids - trial_ids = spiked.index.get_level_values(0).unique() - assert 0 - # copy to cpu - c_spiked = chance_spiked.copy() - # reshape to 2D - c_spiked_reshaped = c_spiked.reshape(d_shape[0], d_shape[1] * d_shape[2]) - # create hierarchical index - col_idx = pd.MultiIndex.from_product( - [spiked.columns, np.arange(repeats)], - names=["unit", "repeat"], - ) - # create df - chance_spiked_df = pd.DataFrame(c_spiked_reshaped, columns=col_idx) - # use the original index - chance_spiked_df.index = spiked.index - # name index - chance_spiked_df.index.names = ["trial", "time"] - assert 0 - - return {"spiked": chance_spiked_df, "fr": chance_fr_df} + # TODO mar 31 2025: does it work or does it give memory error? + return None def get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, From 12398b7a863f90cf44bce5626f785fa1d73f3b76 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 31 Mar 2025 19:00:39 +0100 Subject: [PATCH 232/658] use cached spiked & fr for binning --- pixels/behaviours/base.py | 59 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2135a31..e594fad 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3058,6 +3058,65 @@ def bin_aligned_trials( For VR behaviour, size of positional bin for position data. """ + # TODO mar 31 2025: + # use cached get_aligned_trials to bin so that we do not need to + # duplicate code + + bin_frs = {} + bin_counts = {} + + # get aligned spiked and positions + spiked = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + data="trial_times", + sigma=sigma, + end_event=end_event, + ) + fr = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + data="trial_rate", + sigma=sigma, + end_event=end_event, + ) + + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + # get stream spiked + stream_spiked = spiked[stream_id]["spiked"] + # get stream positions + positions = spiked[stream_id["positions"] + # get stream firing rates + stream_fr = fr[stream_id]["fr"] + + bin_frs[stream_id] = {} + bin_counts[stream_id] = {} + for trial in positions.columns.unique(): + counts = stream_spiked.xs(trial, level="trial", axis=1) + rates = stream_fr.xs(trial, level="trial", axis=1) + trial_pos = positions[trial] + + # get bin spike count + bin_counts[stream_id][trial] = self.bin_vr_trial( + data=counts, + positions=trial_pos, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="sum", + ) + # get bin firing rates + bin_frs[stream_id][trial] = self.bin_vr_trial( + data=rates, + positions=trial_pos, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="mean", + ) + assert 0 + action_labels = self.get_action_labels()[0] # define output path for binned spike rate From 1dac345dc3622faf3b2ff42b0caba7742b459b5f Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:26:18 +0100 Subject: [PATCH 233/658] add import --- pixels/pixels_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4ecbd87..5c0137c 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -13,6 +13,7 @@ import spikeinterface.qualitymetrics as sqm import pixels.signal_utils as signal +from pixels.ioutils import 
write_hdf5 from common_utils.math_utils import random_sampling from common_utils.file_utils import init_memmap, read_hdf5 From 9282e1e13aa4fe561cb98f6d60e2d918e2e76f60 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:27:14 +0100 Subject: [PATCH 234/658] use more accurate name for memmap --- pixels/pixels_utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 5c0137c..2313333 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -567,7 +567,7 @@ def _permute_spikes_n_convolve_fr(array, sigma, sample_rate): def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, - spiked_chance_path, fr_chance_path): + spiked_memmap_path, fr_memmap_path): """ Worker that computes one set of spiked and fr values. @@ -584,9 +584,9 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, chance_data_shape: tuple, shape of chance data. - spiked_chance_path: + spiked_memmap_path: - fr_chance_path: + fr_memmap_path: return === @@ -594,7 +594,7 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, print(f"Processing repeat {i}...") # open readonly memmap spiked = init_memmap( - path=spiked_chance_path.parent/"temp_spiked.bin", + path=spiked_memmap_path.parent/"temp_spiked.bin", shape=spiked_shape, dtype=np.int16, overwrite=False, @@ -603,7 +603,7 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, # init appendable memmap chance_spiked = init_memmap( - path=spiked_chance_path, + path=spiked_memmap_path, shape=chance_data_shape, dtype=np.int16, overwrite=False, @@ -611,7 +611,7 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, ) # init chance firing rate memmap chance_fr = init_memmap( - path=fr_chance_path, + path=fr_memmap_path, shape=chance_data_shape, dtype=np.float32, overwrite=False, @@ -643,9 +643,9 @@ def _get_spike_chance(spiked, sigma, sample_rate, spiked_chance_path, spiked_shape = spiked.shape d_shape = spiked.shape + (repeats,) - if not chance_fr.exists(): + if not fr_memmap_path.exists(): spiked_memmap = init_memmap( - path=spiked_chance_path.parent/"temp_spiked.bin", + path=spiked_memmap_path.parent/"temp_spiked.bin", shape=spiked.shape, dtype=np.int16, overwrite=True, @@ -657,7 +657,7 @@ def _get_spike_chance(spiked, sigma, sample_rate, spiked_chance_path, # init chance spiked memmap chance_spiked = init_memmap( - path=spiked_chance_path, + path=spiked_memmap_path, shape=d_shape, dtype=np.int16, overwrite=True, @@ -665,7 +665,7 @@ def _get_spike_chance(spiked, sigma, sample_rate, spiked_chance_path, ) # init chance firing rate memmap chance_fr = init_memmap( - path=fr_chance_path, + path=fr_memmap_path, shape=d_shape, dtype=np.float32, overwrite=True, @@ -688,8 +688,8 @@ def _get_spike_chance(spiked, sigma, sample_rate, spiked_chance_path, sample_rate=sample_rate, spiked_shape=spiked_shape, chance_data_shape=d_shape, - spiked_chance_path=spiked_chance_path, - fr_chance_path=fr_chance_path, + spiked_memmap_path=spiked_memmap_path, + fr_memmap_path=fr_memmap_path, ) futures.append(future) From 6dfbc67ec88b3a05b0502b1bd4d09b78349ae56a Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:27:37 +0100 Subject: [PATCH 235/658] add higher level func for save_spike_chance to avoid running again if already saved --- pixels/pixels_utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py 
b/pixels/pixels_utils.py index 2313333..dbe99b2 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -632,10 +632,23 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, return None -def _get_spike_chance(spiked, sigma, sample_rate, spiked_chance_path, - fr_chance_path, chance_df_path, repeats=100): +def save_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, + fr_memmap_path, chance_df_path, repeats=100): + if not chance_df_path.exists(): + # save spike chance data if does not exists + _save_spike_chance( + spiked, sigma, sample_rate, spiked_memmap_path, fr_memmap_path, + chance_df_path, repeats) + else: + print(f"> Spike chance already saved at {chance_df_path}, continue.") + + return None + + +def _save_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, + fr_memmap_path, chance_df_path, repeats=100): """ - Implementation of getting chance level spike data. + Implementation of saving chance level spike data. """ import concurrent.futures From 5dda205254083bd098d0e79dcc5a73b660096cf2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:29:33 +0100 Subject: [PATCH 236/658] use _convert_to_df; set `repeat` as the outer most level --- pixels/pixels_utils.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index dbe99b2..1640cfd 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -740,47 +740,48 @@ def _convert_to_df(original_idx, memmap_path, df_path, d_shape, d_type, name): c_spiked = chance_memmap.copy() # reshape to 2D c_spiked_reshaped = c_spiked.reshape(d_shape[0], d_shape[1] * d_shape[2]) + # create hierarchical index col_idx = pd.MultiIndex.from_product( - [spiked.columns, np.arange(repeats)], - names=["unit", "repeat"], + [np.arange(d_shape[2]), orig_col], + names=["repeat", "unit"], ) # create df df = pd.DataFrame(c_spiked_reshaped, columns=col_idx) # use the original index - df.index = original_idx - # name index - df.index.names = ["trial", "time"] + df.index = orig_idx # write h5 to disk - ioutils.write_hdf5( + write_hdf5( path=df_path, df=df, key=name, mode="a", ) + assert 0 return None -def compile_chance( - original_idx, spiked_chance_path, fr_chance_path, chance_df_path, d_shape, -): +def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, + chance_df_path, d_shape): # TODO mar 31 2025: test _convert_to_df # get chance spiked df - chance_spiked_df = _convert_to_df( - original_idx=original_idx, - memmap_path=spiked_chance_path, + _convert_to_df( + orig_idx=orig_idx, + orig_col=orig_col, + memmap_path=spiked_memmap_path, df_path=chance_df_path, d_shape=d_shape, d_type=np.int16, name="spiked", ) # get chance fr df - chance_fr_df = _convert_to_df( - original_idx=original_idx, - memmap_path=fr_chance_path, + _convert_to_df( + orig_idx=orig_idx, + orig_col=orig_col, + memmap_path=fr_memmap_path, df_path=chance_df_path, d_shape=d_shape, d_type=np.float32, From 825a9380c718a584fff60218bc3d63abe9914cfb Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:30:10 +0100 Subject: [PATCH 237/658] if chance data not found, save it first --- pixels/pixels_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1640cfd..95ffe31 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -795,11 +795,10 @@ def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, def 
get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, fr_memmap_path, chance_df_path, repeats=100): - if chance_df_path.exists(): - # read and return - return read_hdf5(chance_df_path) - else: - chance_data = _get_spike_chance( + if not chance_df_path.exists(): + # save spike chance data if does not exists + save_spike_chance( spiked, sigma, sample_rate, spiked_memmap_path, fr_memmap_path, chance_df_path, repeats) - return chance_data + + return read_hdf5(chance_df_path) From 0104ad7a666bdbb48eb7fbc36828d0e89afacb1a Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:30:52 +0100 Subject: [PATCH 238/658] add documentation; change name n args --- pixels/pixels_utils.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 95ffe31..3ad0e56 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -714,19 +714,40 @@ def _save_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, "dataframes.") # convert it to dataframe and save it - # TODO mar 31 2025: how to save it??? - chance_data = compile_chance( - original_idx=spiked.index, - spiked_chance_path=spiked_chance_path, - fr_chance_path=fr_chance_path, + save_chance( + orig_idx=spiked.index, + orig_col=spiked.columns, + spiked_memmap_path=spiked_memmap_path, + fr_memmap_path=fr_memmap_path, chance_df_path=chance_df_path, d_shape=d_shape, ) - return chance_data + print(f"\n> Chance data saved to {chance_df_path}.") + return None + + +def _convert_to_df(orig_idx, orig_col, memmap_path, df_path, d_shape, d_type, + name): + """ + Convert + + orig_idx, + orig_col, + memmap_path + df_path, + d_shape + d_type + name + """ + # NOTE: shape of memmap is `concatenated trials frames * units * repeats`, + # saved df has outer most level being `repeat`, then `unit`, and all trials + # are stacked vertically. + # to later use it for analysis, go into each repeat, and do + # `repeat_df.unstack(level='trial', sort=False)` to get the same structure as + # data. -def _convert_to_df(original_idx, memmap_path, df_path, d_shape, d_type, name): # init readonly chance memmap chance_memmap = init_memmap( path=memmap_path, From 65cd4c89eb28b6206318369f7f9dfd4272e55d01 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:31:49 +0100 Subject: [PATCH 239/658] get spike times and rate for each stream and concat streams into df --- pixels/behaviours/base.py | 103 +++++++++++++++++++++++++------------- 1 file changed, 68 insertions(+), 35 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e594fad..4123513 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1803,11 +1803,10 @@ def _get_aligned_trials( scan_durations = scan_ends - scan_starts cursor = 0 # In sample points + output = {} rec_trials_fr = {} rec_trials_spiked = {} trial_positions = {} - bin_frs = {} - bin_counts = {} streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): @@ -1910,51 +1909,85 @@ def _get_aligned_trials( # different from the real data... # SOLUTION: concat as below, shuffle per column, then convolve per # column - # TODO mar 7 2025: - # CONTINUE HERE! 
+ + # concat trial positions + positions = ioutils.reindex_by_longest( + dfs=trial_positions, + return_format="dataframe", + names="trial", + ) + if data == "trial_times": - spiked = pd.concat( + # get trials vertically stacked spiked + stacked_spiked = pd.concat( rec_trials_spiked[stream_id], axis=0, ) - assert 0 - - continue - s_chance_path = self.interim/stream_files["spiked_shuffled_memmap"] - fr_chance_path = self.interim/stream_files["fr_shuffled_memmap"] - chance_df_path = self.processed/stream_files["shuffled"] - - chance_data = xut.get_spike_chance( - spiked=pd.concat(rec_trials_spiked[stream_id], axis=0), + stacked_spiked.index.names = ["trial", "time"] + stacked_spiked.columns.names = ["unit"] + + # get chance data paths + s_memmap_path = self.interim /\ + stream_files["spiked_shuffled_memmap"] + fr_memmap_path = self.interim /\ + stream_files["fr_shuffled_memmap"] + chance_df_path = self.processed / stream_files["shuffled"] + + # save chance data + xut.save_spike_chance( + spiked=stacked_spiked, sigma=sigma, sample_rate=self.SAMPLE_RATE, - spiked_chance_path=s_chance_path, - fr_chance_path=fr_chance_path, + spiked_memmap_path=s_memmap_path, + fr_memmap_path=fr_memmap_path, chance_df_path=chance_df_path, ) assert 0 - # concat trial df - positions = ioutils.reindex_by_longest( - dfs=trial_positions, - return_format="dataframe", - names="trial", - ) - fr = ioutils.reindex_by_longest( - dfs=rec_trials_fr[stream_id], - return_format="dataframe", - names=["trial", "unit"], - ) - fr = fr.reorder_levels(["unit", "trial"], axis=1) - fr = fr.sort_index(level=0, axis=1) + # unstack and concat horizontally + spiked = stacked_spiked.unstack( + level="trial", + sort=False, + ) + assert 0 + # get trials horizontally stacked spiked + spiked = ioutils.reindex_by_longest( + dfs=rec_trials_spiked[stream_id], + return_format="dataframe", + names=["trial", "unit"], + ) + spiked = spiked.reorder_levels(["unit", "trial"], axis=1) + spiked = spiked.sort_index(level=0, axis=1) - spiked = ioutils.reindex_by_longest( - dfs=rec_trials_spiked[stream_id], - return_format="dataframe", - names=["trial", "unit"], - ) + output[stream_id] = pd.concat( + {"spiked": spiked, "positions": positions}, + axis=1, + names=["data"], + ) + + elif data == "trial_rate": + fr = ioutils.reindex_by_longest( + dfs=rec_trials_fr[stream_id], + return_format="dataframe", + names=["trial", "unit"], + ) + fr = fr.reorder_levels(["unit", "trial"], axis=1) + fr = fr.sort_index(level=0, axis=1) + + output[stream_id] = pd.concat( + {"fr": fr, "positions": positions}, + axis=1, + names=["data"], + ) - return {"fr": fr, "positions": positions} + # concat output into dataframe before cache + df = pd.concat( + objs=output, + axis=1, + names=["stream_id"], + ) + assert 0 + return df def select_units( From 2562a2b119d9728fabdf3cad017737f2e223b876 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:33:08 +0100 Subject: [PATCH 240/658] add missing bracket --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4123513..55baaa5 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3121,7 +3121,7 @@ def bin_aligned_trials( # get stream spiked stream_spiked = spiked[stream_id]["spiked"] # get stream positions - positions = spiked[stream_id["positions"] + positions = spiked[stream_id]["positions"] # get stream firing rates stream_fr = fr[stream_id]["fr"] From eb4157e4d791b055ef424f1f4af285e548f7c109 Mon Sep 17 
00:00:00 2001 From: amz_office Date: Tue, 1 Apr 2025 19:33:23 +0100 Subject: [PATCH 241/658] chance saved along with aligning to trial_times, remove this func --- pixels/behaviours/base.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 55baaa5..8e76ad1 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3341,34 +3341,3 @@ def bin_aligned_trials( print(f"> Output saved at {output_count_path}.") return None - - - @_cacheable - def get_spike_chance( - self, label, event, end_event=None, sigma=None, units=None, - ): - # get aligned firing rates and positions - trials = self.align_trials( - units=units, # NOTE: ALWAYS the first arg - label=label, - event=event, - data="trial_times", - sigma=sigma, - end_event=end_event, - ) - - assert 0 - # TODO mar 5 2025: - # how to get spike times then get chance? and cache them? - - # get chance data - chance_data = xut.get_spike_chance( - spiked=pd.concat(rec_trials_spiked[stream_id], axis=0), - sigma=sigma, - sample_rate=self.SAMPLE_RATE, - spiked_chance_path=s_chance_path, - fr_chance_path=fr_chance_path, - chance_df_path=chance_df_path, - ) - - return chance_data From 50440c9be7dd8256819b4236700845716888bb6e Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 13:58:46 +0100 Subject: [PATCH 242/658] separate spiked chance and fr chance data cuz they r big --- pixels/ioutils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 903642b..305cc51 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -133,8 +133,11 @@ def get_data_files(data_dir, session_name): f"{session_name}_{stream_id[:-3]}_fr_shuffled.bin" ) # .h5 files - pixels[stream_id]["shuffled"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_shuffled.h5" + pixels[stream_id]["spiked_shuffled"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_spiked_shuffled.h5" + ) + pixels[stream_id]["fr_shuffled"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_fr_shuffled.h5" ) # old catgt data From 1ebb278d6438466c17094bcfdbae0c5cf59ab126 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 13:59:08 +0100 Subject: [PATCH 243/658] allows taking key as arg --- pixels/ioutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 305cc51..e727fb5 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -340,7 +340,7 @@ def save_ndarray_as_video(video, path, frame_rate, dims=None): raise PixelsError(f"Video creation failed: {path}") -def read_hdf5(path): +def read_hdf5(path, key="df"): """ Read a dataframe from a h5 file. @@ -356,7 +356,7 @@ def read_hdf5(path): """ df = pd.read_hdf( path_or_buf=path, - key="df", + key=key, ) return df From d87d1f913cd0c94b8334b943f72649d67521245e Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 13:59:31 +0100 Subject: [PATCH 244/658] use df.unstack to turn long format into wide format more efficiently --- pixels/ioutils.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index e727fb5..7a070e4 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -752,24 +752,27 @@ def reindex_by_longest(dfs, return_format="array", names=None): === np.array or pd.DataFrame. 
""" - # align all trials by index - indices = list(set().union( - *[df.index for df in dfs.values()]) - ) + # stack dfs vertically + stacked_df = pd.concat(dfs, axis=0) + + # set index name + if idx_names: + stacked_df.index.names = idx_names - # reindex by the longest - reidx_dfs = {key: df.reindex(index=indices) - for key, df in dfs.items()} + # unstack df at level + df = stacked_df.unstack(level=level, sort=sort) if return_format == "array": # stack df values into np array - output = np.stack( - [df.values for df in reidx_dfs.values()], - axis=-1, - ) + output = df.values.squeeze() + elif return_format == "dataframe": - # concatenate dfs - output = pd.concat(reidx_dfs, axis=1, names=names) + if col_names: + if isinstance(stacked_df, pd.Series): + stacked_df.columns.names = col_names + elif isinstance(stacked_df, pd.DataFrame): + stacked_df.name = col_names[0] + output = stacked_df return output From 840563e66e7390f5ac33504ef6bbf5502e9c6de9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 13:59:59 +0100 Subject: [PATCH 245/658] allows to define index and col names --- pixels/ioutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 7a070e4..89a8e0d 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -736,7 +736,8 @@ def stream_video(video, length=None): if length == 0: break -def reindex_by_longest(dfs, return_format="array", names=None): +def reindex_by_longest(dfs, idx_names=None, col_names=None, level=0, sort=True, + return_format="array"): """ params === From c23d87adbe0ecdcba4c4ac24c42ac627d2ab71da Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:00:26 +0100 Subject: [PATCH 246/658] more clear print --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 3ad0e56..c6a6063 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -711,7 +711,7 @@ def _save_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, future.result() else: print("\n> Memmaps already created, only need to convert into " - "dataframes.") + "dataframes and save.") # convert it to dataframe and save it save_chance( From 68ecce051d528487bd6f0f7815d65390cadb1154 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:01:06 +0100 Subject: [PATCH 247/658] put paths in front to allow dict unpacking --- pixels/pixels_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c6a6063..3e8f621 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -632,21 +632,21 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, return None -def save_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, - fr_memmap_path, chance_df_path, repeats=100): - if not chance_df_path.exists(): +def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, + fr_df_path, spiked, sigma, sample_rate, repeats=100): + if not fr_df_path.exists(): # save spike chance data if does not exists _save_spike_chance( - spiked, sigma, sample_rate, spiked_memmap_path, fr_memmap_path, - chance_df_path, repeats) + spiked_memmap_path, fr_memmap_path, spiked_df_path, fr_df_path, + spiked, sigma, sample_rate, repeats) else: - print(f"> Spike chance already saved at {chance_df_path}, continue.") + print(f"> Spike chance already saved at {fr_df_path}, continue.") return None -def _save_spike_chance(spiked, 
sigma, sample_rate, spiked_memmap_path, - fr_memmap_path, chance_df_path, repeats=100): +def _save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, + fr_df_path, spiked, sigma, sample_rate, repeats): """ Implementation of saving chance level spike data. """ From dc279892dc124298b97684a940f0013a4211bf97 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:01:49 +0100 Subject: [PATCH 248/658] separate spiked and fr chance data --- pixels/pixels_utils.py | 67 ++++++++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 3e8f621..2ec9656 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -719,11 +719,12 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, orig_col=spiked.columns, spiked_memmap_path=spiked_memmap_path, fr_memmap_path=fr_memmap_path, - chance_df_path=chance_df_path, + spiked_df_path=spiked_df_path, + fr_df_path=fr_df_path, d_shape=d_shape, ) - print(f"\n> Chance data saved to {chance_df_path}.") + print(f"\n> Chance data saved to {fr_df_path}.") return None @@ -778,22 +779,30 @@ def _convert_to_df(orig_idx, orig_col, memmap_path, df_path, d_shape, d_type, path=df_path, df=df, key=name, - mode="a", + mode="w", ) - assert 0 + del df return None def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, - chance_df_path, d_shape): - # TODO mar 31 2025: test _convert_to_df + spiked_df_path, fr_df_path, d_shape): + """ + Saving chance data to dataframe. + + params + === + orig_idx: pandas + """ + print(f"\n> Saving chance data...") + # get chance spiked df _convert_to_df( orig_idx=orig_idx, orig_col=orig_col, memmap_path=spiked_memmap_path, - df_path=chance_df_path, + df_path=spiked_df_path, d_shape=d_shape, d_type=np.int16, name="spiked", @@ -803,23 +812,49 @@ def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, orig_idx=orig_idx, orig_col=orig_col, memmap_path=fr_memmap_path, - df_path=chance_df_path, + df_path=fr_df_path, d_shape=d_shape, d_type=np.float32, name="fr", ) - assert 0 - # TODO mar 31 2025: does it work or does it give memory error? return None def get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, - fr_memmap_path, chance_df_path, repeats=100): - if not chance_df_path.exists(): + fr_memmap_path, fr_df_path, repeats=100): + if not fr_df_path.exists(): # save spike chance data if does not exists - save_spike_chance( - spiked, sigma, sample_rate, spiked_memmap_path, fr_memmap_path, - chance_df_path, repeats) + save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, + fr_df_path, spiked, sigma, sample_rate, repeats) + #else: + # d_shape = spiked.shape + (repeats,) + + # spiked_chance = _get_spike_chance( + # path=spiked_memmap_path, + # shape=d_shape, + # dtype=np.int16, + # overwrite=False, + # readonly=True, + # ) + + # fr_chance = _get_spike_chance( + # path=fr_memmap_path, + # shape=d_shape, + # dtype=np.float32, + # overwrite=False, + # readonly=True, + # ) + assert 0 + # TODO apr 2 2025: loading df is too big and gives memory error???? 
+ + return read_hdf5(spiked_df_path), read_hdf5(fr_df_path) + + +def _get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, + fr_memmap_path, fr_df_path, repeats=100): - return read_hdf5(chance_df_path) + # TODO apr 2 2025: + # for fr chance, use memmap, go to each repeat, unstack, bin, then save it + # to .npz for andrew + pass From 17f7e2cda3e5b141a57d2bd7838d6901202afa18 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:02:07 +0100 Subject: [PATCH 249/658] delete cpu copy to save memory --- pixels/pixels_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 2ec9656..6ef3c00 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -762,6 +762,7 @@ def _convert_to_df(orig_idx, orig_col, memmap_path, df_path, d_shape, d_type, c_spiked = chance_memmap.copy() # reshape to 2D c_spiked_reshaped = c_spiked.reshape(d_shape[0], d_shape[1] * d_shape[2]) + del c_spiked # create hierarchical index col_idx = pd.MultiIndex.from_product( From 3bcb27b7e51e254bf522a733f74cd77230588c83 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:03:18 +0100 Subject: [PATCH 250/658] use unstacked implemented func --- pixels/behaviours/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 8e76ad1..64c5dce 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1913,9 +1913,12 @@ def _get_aligned_trials( # concat trial positions positions = ioutils.reindex_by_longest( dfs=trial_positions, + idx_names=["trial", "time"], + col_names=["position"], + level="trial", return_format="dataframe", - names="trial", ) + assert 0 if data == "trial_times": # get trials vertically stacked spiked From 3b5525acffee418346e74c08a2f67094df261adf Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:03:38 +0100 Subject: [PATCH 251/658] use dict unpack to make args more clear --- pixels/behaviours/base.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 64c5dce..230f971 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1930,27 +1930,27 @@ def _get_aligned_trials( stacked_spiked.columns.names = ["unit"] # get chance data paths - s_memmap_path = self.interim /\ - stream_files["spiked_shuffled_memmap"] - fr_memmap_path = self.interim /\ - stream_files["fr_shuffled_memmap"] - chance_df_path = self.processed / stream_files["shuffled"] + paths = { + "spiked_memmap_path": self.interim /\ + stream_files["spiked_shuffled_memmap"], + "fr_memmap_path": self.interim /\ + stream_files["fr_shuffled_memmap"], + "spiked_df_path": self.processed / stream_files["spiked_shuffled"], + "fr_df_path": self.processed / stream_files["fr_shuffled"], + } # save chance data xut.save_spike_chance( + **paths, spiked=stacked_spiked, sigma=sigma, sample_rate=self.SAMPLE_RATE, - spiked_memmap_path=s_memmap_path, - fr_memmap_path=fr_memmap_path, - chance_df_path=chance_df_path, ) - assert 0 # unstack and concat horizontally spiked = stacked_spiked.unstack( level="trial", - sort=False, + sort=True, ) assert 0 # get trials horizontally stacked spiked From 3ab54c20d42c7c7042e8c6a3808b0718594d66a5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:05:49 +0100 Subject: [PATCH 252/658] load .h5 chance data --- pixels/behaviours/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git 
a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 230f971..256caea 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3344,3 +3344,18 @@ def bin_aligned_trials( print(f"> Output saved at {output_count_path}.") return None + + def get_chance_data(self): + + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + + spiked_chance_path = self.processed / stream_files["spiked_shuffled"] + fr_chance_path = self.processed / stream_files["fr_shuffled"] + + fr_chance = ioutils.read_hdf5(fr_chance_path, key="fr") + assert 0 + spiked_chance = ioutils.read_hdf5(spiked_chance_path, key="spiked") + assert 0 + + return df From 7c744c2159bd7fbcab0f0c50e2315439811bc70b Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 2 Apr 2025 14:10:46 +0100 Subject: [PATCH 253/658] series only takes one name, df takes a list --- pixels/ioutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 89a8e0d..d4615e8 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -770,9 +770,9 @@ def reindex_by_longest(dfs, idx_names=None, col_names=None, level=0, sort=True, elif return_format == "dataframe": if col_names: if isinstance(stacked_df, pd.Series): - stacked_df.columns.names = col_names - elif isinstance(stacked_df, pd.DataFrame): stacked_df.name = col_names[0] + elif isinstance(stacked_df, pd.DataFrame): + stacked_df.columns.names = col_names output = stacked_df return output From bf2a44f42818fcc154eaa6deb9cee4a5c20f52a0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:27:47 +0100 Subject: [PATCH 254/658] add todo --- pixels/experiment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/experiment.py b/pixels/experiment.py index cf9e981..7cf1c87 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -246,6 +246,8 @@ def align_trials(self, *args, units=None, **kwargs): names=["session", "trial", "scorer", "bodyparts", "coords"] ) + # TODO apr 3 2025: + # make sure trial_times is here too if "trial_rate" in kwargs.values(): frs = {} positions = {} From c11c515b403ef109216f936d70140a386c2648cc Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:28:01 +0100 Subject: [PATCH 255/658] allow defining format --- pixels/ioutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index d4615e8..9aaf30f 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -361,7 +361,7 @@ def read_hdf5(path, key="df"): return df -def write_hdf5(path, df, key="df", mode="w"): +def write_hdf5(path, df, key="df", mode="w", format="fixed"): """ Write a dataframe to an h5 file. 
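For context on the format argument introduced by this patch (an illustrative note, not part of PATCH 255): the value is passed straight through to pandas' DataFrame.to_hdf. format="fixed", the default, writes a fast store whose key cannot later be appended to or queried, while format="table" writes a PyTables table that is slower to write but supports row appends and on-disk where-queries via pd.read_hdf. A hedged usage sketch of the patched helper, with a made-up output path:

    import pandas as pd
    from pixels.ioutils import write_hdf5

    fr = pd.DataFrame({"unit_1": [0.0, 2.1, 0.4], "unit_2": [1.3, 0.0, 0.9]})
    # mode="a" keeps any keys already stored in the file; format="table" makes
    # the "fr" key readable later with pd.read_hdf(path, key="fr", where=...)
    write_hdf5("/tmp/session_shuffled.h5", fr, key="fr", mode="a", format="table")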
@@ -388,6 +388,7 @@ def write_hdf5(path, df, key="df", mode="w"): path_or_buf=path, key=key, mode=mode, + format=format, complevel=9, #complib="bzip2", # slower but higher compression ratio complib="blosc:lz4hc", From b2fff160678096c70fbed3210bc2319d3c5cda1d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:28:41 +0100 Subject: [PATCH 256/658] allow defining return format when reindex; add helper funcs --- pixels/ioutils.py | 71 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 9aaf30f..b68e57b 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -737,7 +737,7 @@ def stream_video(video, length=None): if length == 0: break -def reindex_by_longest(dfs, idx_names=None, col_names=None, level=0, sort=True, +def reindex_by_longest(dfs, idx_names=None, level=0, sort=True, return_format="array"): """ params @@ -754,27 +754,62 @@ def reindex_by_longest(dfs, idx_names=None, col_names=None, level=0, sort=True, === np.array or pd.DataFrame. """ - # stack dfs vertically - stacked_df = pd.concat(dfs, axis=0) - - # set index name - if idx_names: - stacked_df.index.names = idx_names - - # unstack df at level - df = stacked_df.unstack(level=level, sort=sort) - if return_format == "array": + # align all trials by index + indices = list(set().union( + *[df.index for df in dfs.values()]) + ) + # reindex by the longest + reidx_dfs = {key: df.reindex(index=indices) + for key, df in dfs.items()} # stack df values into np array - output = df.values.squeeze() + # NOTE: this create multidimensional data, different from if return + # format is df! + output = np.stack( + [df.values for df in reidx_dfs.values()], + axis=-1, + ) elif return_format == "dataframe": - if col_names: - if isinstance(stacked_df, pd.Series): - stacked_df.name = col_names[0] - elif isinstance(stacked_df, pd.DataFrame): - stacked_df.columns.names = col_names - output = stacked_df + if isinstance(dfs, dict): + # stack dfs vertically + stacked_df = pd.concat(dfs, axis=0) + # set index name + if idx_names: + stacked_df.index.names = idx_names + elif isinstance(dfs, pd.DataFrame): + stacked_df = dfs + + # unstack df at level + output = stacked_df.unstack(level=level, sort=sort) + del stacked_df return output +def is_nested_dict(d): + """ + Returns True if at least one value in dictionary d is a dict. 
+ """ + return any(isinstance(v, dict) for v in d.values()) + + +def save_index_to_frame(df, path): + idx_df = df.index.to_frame(index=False) + write_hdf5( + path=path, + df=idx_df, + key="multiindex", + ) + # NOTE: to reconstruct: + # recs_idx = pd.MultiIndex.from_frame(idx_df) + + +def save_cols_to_frame(df, path): + col_df = df.columns.to_frame(index=False) + write_hdf5( + path=path, + df=col_df, + key="cols", + ) + # NOTE: to reconstruct: + # df.columns = col_df.values From 7b23fff0b6c15705eea587b1e0fd33b2007ef213 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:30:33 +0100 Subject: [PATCH 257/658] add file names for shuffled index, column, and shape --- pixels/ioutils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index b68e57b..e06ddc6 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -132,6 +132,15 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["fr_shuffled_memmap"] = base_name.with_name( f"{session_name}_{stream_id[:-3]}_fr_shuffled.bin" ) + pixels[stream_id]["shuffled_shape"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_shuffled_shape.json" + ) + pixels[stream_id]["shuffled_index"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_shuffled_index.h5" + ) + pixels[stream_id]["shuffled_columns"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_shuffled_columns.h5" + ) # .h5 files pixels[stream_id]["spiked_shuffled"] = base_name.with_name( f"{session_name}_{stream_id[:-3]}_spiked_shuffled.h5" From 249e6964b0d8c6bbc68efe38e6e7cd0e0998ce72 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:31:17 +0100 Subject: [PATCH 258/658] add todo --- pixels/behaviours/base.py | 63 ++++++++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 256caea..e46e6c9 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -109,6 +109,11 @@ def func(*args, **kwargs): keys = store.keys() # create df as a dictionary to hold all dfs df = {} + # TODO apr 2 2025: for now the nested dict have keys in the + # format of `/imec0.ap/positions`, this will not be the case + # once i flatten files at the stream level rather than + # session level, i.e., every pixels related cache will have + # stream id in their name. for key in keys: # read current df data = store[key] @@ -3345,17 +3350,59 @@ def bin_aligned_trials( return None - def get_chance_data(self): + def get_chance_data(self, time_bin, pos_bin): streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - spiked_chance_path = self.processed / stream_files["spiked_shuffled"] - fr_chance_path = self.processed / stream_files["fr_shuffled"] - - fr_chance = ioutils.read_hdf5(fr_chance_path, key="fr") - assert 0 - spiked_chance = ioutils.read_hdf5(spiked_chance_path, key="spiked") + paths = { + "spiked_memmap_path": self.interim /\ + stream_files["spiked_shuffled_memmap"], + "fr_memmap_path": self.interim /\ + stream_files["fr_shuffled_memmap"], + "memmap_shape_path": self.interim /\ + stream_files["shuffled_shape"], + "idx_path": self.interim / stream_files["shuffled_index"], + "col_path": self.interim /\ + stream_files["shuffled_columns"], + } + + # TODO apr 3 2025: how the fuck to get positions here???? + # TEMP: get it manually... 
+ # light + pos_path = self.interim /\ + "cache/align_trials_all_trial_times_725_1_100_512.h5" + # dark + #pos_path = self.interim /\ + # "cache/align_trials_all_trial_times_1322_1_100_512.h5" + + with pd.HDFStore(pos_path, "r") as store: + # list all keys + keys = store.keys() + # create df as a dictionary to hold all dfs + df = {} + # TODO apr 2 2025: for now the nested dict have keys in the + # format of `/imec0.ap/positions`, this will not be the case + # once i flatten files at the stream level rather than + # session level, i.e., every pixels related cache will have + # stream id in their name. + for key in keys: + # read current df + data = store[key] + # remove "/" in key + key_name = key.lstrip("/") + # use key name as dict key + df[key_name] = data + positions = df[f"{stream_id[:-3]}/positions"] + + xut.get_spike_chance( + sample_rate=self.SAMPLE_RATE, + positions=positions, + time_bin=time_bin, + pos_bin=pos_bin, + **paths, + ) assert 0 + #spiked_chance = ioutils.read_hdf5(spiked_chance_path, key="spiked") - return df + return None From 301b2b971c7122b5709c72f477aca9a301df63b9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:32:15 +0100 Subject: [PATCH 259/658] if nested dict write them in a separate way --- pixels/behaviours/base.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e46e6c9..54fa9de 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -129,7 +129,19 @@ def func(*args, **kwargs): output.touch() else: # allows to save multiple dfs in a dict in one hdf5 file - if isinstance(df, dict): + if ioutils.is_nested_dict(df): + for stream_id, nested_dict in df.items(): + # NOTE: we remove `.ap` in stream id cuz having `.`in + # the key name get problems + for name, values in nested_dict.items(): + key = f"/{stream_id}/{name}" + ioutils.write_hdf5( + path=output, + df=values, + key=key, + mode="a", + ) + elif isinstance(df, dict): for name, values in df.items(): ioutils.write_hdf5( path=output, From 712797dfb0667d008e559b2f37153427baca2ea3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:33:26 +0100 Subject: [PATCH 260/658] drop .ap in stream id so that it does not cause pandas probs --- pixels/behaviours/base.py | 76 +++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 54fa9de..b56f241 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1827,9 +1827,10 @@ def _get_aligned_trials( streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = stream_id[:-3] # allows multiple streams of recording, i.e., multiple probes - rec_trials_fr[stream_id] = {} - rec_trials_spiked[stream_id] = {} + rec_trials_fr[stream] = {} + rec_trials_spiked[stream] = {} # Account for multiple raw data files meta = self.ap_meta[stream_num] @@ -1908,13 +1909,13 @@ def _get_aligned_trials( spiked = spiked.iloc[scan_pad: -scan_pad] # reset index to zero at the beginning of the trial rates.reset_index(inplace=True, drop=True) - rec_trials_fr[stream_id][trial_ids[i]] = rates + rec_trials_fr[stream][trial_ids[i]] = rates spiked.reset_index(inplace=True, drop=True) - rec_trials_spiked[stream_id][trial_ids[i]] = spiked + rec_trials_spiked[stream][trial_ids[i]] = spiked trial_pos.reset_index(inplace=True, drop=True) trial_positions[trial_ids[i]] = trial_pos - #if not 
rec_trials_fr[stream_id]: + #if not rec_trials_fr[stream]: # return None # TODO feb 28 2025: @@ -1931,21 +1932,29 @@ def _get_aligned_trials( positions = ioutils.reindex_by_longest( dfs=trial_positions, idx_names=["trial", "time"], - col_names=["position"], level="trial", return_format="dataframe", ) - assert 0 if data == "trial_times": # get trials vertically stacked spiked stacked_spiked = pd.concat( - rec_trials_spiked[stream_id], + rec_trials_spiked[stream], axis=0, ) stacked_spiked.index.names = ["trial", "time"] stacked_spiked.columns.names = ["unit"] + # save index and columns to reconstruct df for shuffled data + ioutils.save_index_to_frame( + df=stacked_spiked, + path=self.interim / stream_files["shuffled_index"], + ) + ioutils.save_cols_to_frame( + df=stacked_spiked, + path=self.interim / stream_files["shuffled_columns"], + ) + # get chance data paths paths = { "spiked_memmap_path": self.interim /\ @@ -1959,55 +1968,46 @@ def _get_aligned_trials( # save chance data xut.save_spike_chance( **paths, - spiked=stacked_spiked, sigma=sigma, sample_rate=self.SAMPLE_RATE, + spiked=stacked_spiked, ) - # unstack and concat horizontally - spiked = stacked_spiked.unstack( - level="trial", - sort=True, - ) - assert 0 # get trials horizontally stacked spiked spiked = ioutils.reindex_by_longest( - dfs=rec_trials_spiked[stream_id], + dfs=stacked_spiked, + level="trial", return_format="dataframe", - names=["trial", "unit"], ) - spiked = spiked.reorder_levels(["unit", "trial"], axis=1) - spiked = spiked.sort_index(level=0, axis=1) - output[stream_id] = pd.concat( - {"spiked": spiked, "positions": positions}, - axis=1, - names=["data"], - ) + output[stream] = { + "spiked": spiked, + "positions": positions, + } elif data == "trial_rate": + # TODO apr 2 2025: make sure this reindex_by_longest works + assert 0 fr = ioutils.reindex_by_longest( - dfs=rec_trials_fr[stream_id], + dfs=rec_trials_fr[stream], return_format="dataframe", - names=["trial", "unit"], + idx_names=["trial", "unit"], ) fr = fr.reorder_levels(["unit", "trial"], axis=1) fr = fr.sort_index(level=0, axis=1) - output[stream_id] = pd.concat( - {"fr": fr, "positions": positions}, - axis=1, - names=["data"], - ) + output[stream] = { + "fr": fr, + "positions": positions, + } # concat output into dataframe before cache - df = pd.concat( - objs=output, - axis=1, - names=["stream_id"], - ) - assert 0 - return df + #df = pd.concat( + # objs=output, + # axis=1, + # names=["stream"], + #) + return output def select_units( From 7f080d4330bf31459da520bfb0af8a094538e427 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:34:44 +0100 Subject: [PATCH 261/658] add todo --- pixels/behaviours/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b56f241..daf3658 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3341,6 +3341,8 @@ def bin_aligned_trials( # stack df values into np array # reshape into trials x units x bins + assert 0 + # TODO apr 2 2025: make sure this reindex_by_longest works bin_count_arr = ioutils.reindex_by_longest(bin_counts).T bin_fr_arr = ioutils.reindex_by_longest(bin_frs).T From 22b4535f3842df6f5b45d7c9718ea90a8b55505a Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:35:43 +0100 Subject: [PATCH 262/658] put type of data as second arg in align trials --- pixels/behaviours/base.py | 41 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/pixels/behaviours/base.py 
b/pixels/behaviours/base.py index daf3658..5b58116 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3004,9 +3004,9 @@ def get_positional_rate( # get aligned firing rates and positions trials = self.align_trials( units=units, # NOTE: ALWAYS the first arg + data="trial_rate", # NOTE: ALWAYS the second arg label=label, event=event, - data="trial_rate", sigma=sigma, end_event=end_event, ) @@ -3117,40 +3117,49 @@ def bin_aligned_trials( bin_frs = {} bin_counts = {} + bin_counts_chance = {} # get aligned spiked and positions spiked = self.align_trials( units=units, # NOTE: ALWAYS the first arg + data="trial_times", # NOTE: ALWAYS the second arg label=label, event=event, - data="trial_times", - sigma=sigma, - end_event=end_event, - ) - fr = self.align_trials( - units=units, # NOTE: ALWAYS the first arg - label=label, - event=event, - data="trial_rate", sigma=sigma, end_event=end_event, ) + #fr = self.align_trials( + # units=units, # NOTE: ALWAYS the first arg + # data="trial_rate", # NOTE: ALWAYS the second arg + # label=label, + # event=event, + # sigma=sigma, + # end_event=end_event, + #) streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + key = f"{stream_id[:-3]}/" # get stream spiked - stream_spiked = spiked[stream_id]["spiked"] + stream_spiked = spiked[key + "spiked"] # get stream positions - positions = spiked[stream_id]["positions"] + positions = spiked[key + "positions"] # get stream firing rates - stream_fr = fr[stream_id]["fr"] + #stream_fr = fr[stream]["fr"] - bin_frs[stream_id] = {} - bin_counts[stream_id] = {} + assert 0 + spiked_chance_path = self.processed / stream_files["spiked_shuffled"] + spiked_chance = ioutils.read_hdf5(spiked_chance_path, "spiked") + + bin_frs[stream] = {} + bin_counts[stream] = {} + bin_counts_chance[stream] = {} for trial in positions.columns.unique(): counts = stream_spiked.xs(trial, level="trial", axis=1) - rates = stream_fr.xs(trial, level="trial", axis=1) + #rates = stream_fr.xs(trial, level="trial", axis=1) trial_pos = positions[trial] + assert 0 + counts = stream_spiked.xs(trial, level="trial", axis=1).dropna() # get bin spike count bin_counts[stream_id][trial] = self.bin_vr_trial( From 571061f35014790c28244d8252b162a318b16b9c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:36:09 +0100 Subject: [PATCH 263/658] add imports --- pixels/pixels_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 6ef3c00..18da43d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -13,7 +13,8 @@ import spikeinterface.qualitymetrics as sqm import pixels.signal_utils as signal -from pixels.ioutils import write_hdf5 +from pixels.ioutils import write_hdf5, reindex_by_longest +from pixels.error import PixelsError from common_utils.math_utils import random_sampling from common_utils.file_utils import init_memmap, read_hdf5 From 296ee5877fffbc7cc07a21844286907c0fe48d8f Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:36:31 +0100 Subject: [PATCH 264/658] add import --- pixels/pixels_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 18da43d..3d2e32f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import json import numpy as np import pandas as pd From dab869cfd109834a7cd1b2296ac775bbd82e7216 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 
Apr 2025 17:36:43 +0100 Subject: [PATCH 265/658] initiate rng everytime with chance worker to avoid having the same seed --- pixels/pixels_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 3d2e32f..bee6f36 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -28,9 +28,6 @@ ) si.set_global_job_kwargs(**job_kwargs) -# initiate random number generator -rng = np.random.default_rng() - def load_raw(paths, stream_id): """ Load raw recording file from spikeglx. @@ -556,6 +553,9 @@ def _permute_spikes_n_convolve_fr(array, sigma, sample_rate): random_fr: convolved firing rate from shuffled spike boolean for each unit. """ + # initiate random number generator every time to avoid same results from the + # same seeding + rng = np.random.default_rng() # permutate columns random_spiked = rng.permuted(array, axis=0) # convolve into firing rate From 0ff40fc2b6413c3c94bdac8732d5a6506a3e01f0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:49:31 +0100 Subject: [PATCH 266/658] use concat_spiked_path --- pixels/pixels_utils.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index bee6f36..d921b1f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -569,7 +569,7 @@ def _permute_spikes_n_convolve_fr(array, sigma, sample_rate): def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, - spiked_memmap_path, fr_memmap_path): + spiked_memmap_path, fr_memmap_path, concat_spiked_path): """ Worker that computes one set of spiked and fr values. @@ -596,7 +596,7 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, print(f"Processing repeat {i}...") # open readonly memmap spiked = init_memmap( - path=spiked_memmap_path.parent/"temp_spiked.bin", + path=concat_spiked_path, shape=spiked_shape, dtype=np.int16, overwrite=False, @@ -635,12 +635,14 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, - fr_df_path, spiked, sigma, sample_rate, repeats=100): - if not fr_df_path.exists(): + fr_df_path, sigma, sample_rate, repeats=100, spiked=None, + spiked_shape=None, concat_spiked_path=None): + if fr_df_path.exists(): # save spike chance data if does not exists _save_spike_chance( - spiked_memmap_path, fr_memmap_path, spiked_df_path, fr_df_path, - spiked, sigma, sample_rate, repeats) + spiked_memmap_path, fr_memmap_path, spiked_df_path, fr_df_path, sigma, + sample_rate, repeats, spiked, spiked_shape, + concat_spiked_path) else: print(f"> Spike chance already saved at {fr_df_path}, continue.") @@ -648,19 +650,24 @@ def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, def _save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, - fr_df_path, spiked, sigma, sample_rate, repeats): + fr_df_path, sigma, sample_rate, repeats, spiked, + spiked_shape, concat_spiked_path): """ Implementation of saving chance level spike data. """ import concurrent.futures - # get export data shape - spiked_shape = spiked.shape - d_shape = spiked.shape + (repeats,) - - if not fr_memmap_path.exists(): + # save spiked to memmap if not yet + # TODO apr 9 2025: if i have temp_spiked, how to get its shape? do i need + # another input arg??? this is to run it again without get the concat spiked + # again... 
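# A hedged sketch of one way to settle the TODO above without adding another
# argument: when the concatenated spiked .bin already exists, its row count
# can be recovered from the file size given the number of unit columns, or the
# full shape can be dumped to a small json next to the .bin when it is first
# written. Names and signatures below are illustrative, not pipeline API.
import json
import numpy as np
from pathlib import Path

def open_spiked_memmap(bin_path, n_units, dtype=np.int16):
    # rows = total bytes / (bytes per value * values per row)
    n_bytes = Path(bin_path).stat().st_size
    n_rows = n_bytes // (np.dtype(dtype).itemsize * n_units)
    return np.memmap(bin_path, dtype=dtype, mode="r", shape=(n_rows, n_units))

def dump_memmap_shape(json_path, shape):
    with open(json_path, "w") as fd:
        json.dump({"dshape": list(shape)}, fd)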
+ if spiked is None: + assert concat_spiked_path.exists() + assert spiked_shape is not None + else: + concat_spiked_path = spiked_memmap_path.parent/"temp_spiked.bin" spiked_memmap = init_memmap( - path=spiked_memmap_path.parent/"temp_spiked.bin", + path=concat_spiked_path, shape=spiked.shape, dtype=np.int16, overwrite=True, @@ -705,6 +712,7 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, chance_data_shape=d_shape, spiked_memmap_path=spiked_memmap_path, fr_memmap_path=fr_memmap_path, + concat_spiked_path=concat_spiked_path, ) futures.append(future) From 508d8c3c2c39a60b61de03a2ae50c1a4f5d89682 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 17:59:18 +0100 Subject: [PATCH 267/658] move bin_vr_trial to pixels utils --- pixels/behaviours/base.py | 53 --------------------------------- pixels/pixels_utils.py | 62 ++++++++++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 61 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 5b58116..db7f7be 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2929,59 +2929,6 @@ def get_aligned_spike_rate_CI( return df - def bin_vr_trial(self, data, positions, time_bin, pos_bin, bin_method="mean"): - """ - Bin virtual reality trials by given temporal bin and positional bin. - - params - === - data: pandas dataframe, neural data needed binning. - - positions: pandas dataframe, position of current trial. - - time_bin: str, temporal bin for neural data. - - pos_bin: int, positional bin for positions. - - bin_method: str, method to concatenate data within each temporal bin. - "mean": taking the mean of all frames. - "sum": taking sum of all frames. - """ - data = data.copy() - positions = positions.copy() - - # convert index to datetime index for resampling - isi = (1 / self.SAMPLE_RATE) * 1000 - data.index = pd.to_timedelta( - arg=data.index * isi, - unit="ms", - ) - - # set position index too - positions.index = data.index - - # resample to 100ms bin, and get position mean - mean_pos = positions.resample(time_bin).mean() - - if bin_method == "sum": - # resample to Xms bin, and get sum - bin_data = data.resample(time_bin).sum() - elif bin_method == "mean": - # resample to Xms bin, and get mean - bin_data = data.resample(time_bin).mean() - - # add position here to bin together - bin_data['positions'] = mean_pos.values - # add bin positions - bin_pos = mean_pos // pos_bin + 1 - bin_data['bin_pos'] = bin_pos.values - - # use numeric index - bin_data.reset_index(inplace=True, drop=True) - - return bin_data - - @_cacheable def get_positional_rate( self, label, event, end_event=None, sigma=None, units=None, diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d921b1f..4d5d664 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -856,16 +856,62 @@ def get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, # overwrite=False, # readonly=True, # ) - assert 0 - # TODO apr 2 2025: loading df is too big and gives memory error???? - - return read_hdf5(spiked_df_path), read_hdf5(fr_df_path) - - -def _get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, - fr_memmap_path, fr_df_path, repeats=100): # TODO apr 2 2025: # for fr chance, use memmap, go to each repeat, unstack, bin, then save it # to .npz for andrew pass + + +def bin_vr_trial(data, positions, sample_rate, time_bin, pos_bin, + bin_method="mean"): + """ + Bin virtual reality trials by given temporal bin and positional bin. 
+ + params + === + data: pandas dataframe, neural data needed binning. + + positions: pandas dataframe, position of current trial. + + time_bin: str, temporal bin for neural data. + + pos_bin: int, positional bin for positions. + + bin_method: str, method to concatenate data within each temporal bin. + "mean": taking the mean of all frames. + "sum": taking sum of all frames. + """ + data = data.copy() + positions = positions.copy() + + # convert index to datetime index for resampling + isi = (1 / sample_rate) * 1000 + data.index = pd.to_timedelta( + arg=data.index * isi, + unit="ms", + ) + + # set position index too + positions.index = data.index + + # resample to 100ms bin, and get position mean + mean_pos = positions.resample(time_bin).mean() + + if bin_method == "sum": + # resample to Xms bin, and get sum + bin_data = data.resample(time_bin).sum() + elif bin_method == "mean": + # resample to Xms bin, and get mean + bin_data = data.resample(time_bin).mean() + + # add position here to bin together + bin_data['positions'] = mean_pos.values + # add bin positions + bin_pos = mean_pos // pos_bin + 1 + bin_data['bin_pos'] = bin_pos.values + + # use numeric index + bin_data.reset_index(inplace=True, drop=True) + + return bin_data From 65cfe949b34146d8fca169fa470beb8a5fc77ca5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 18:04:41 +0100 Subject: [PATCH 268/658] define dshape after spiked_shape defined --- pixels/pixels_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4d5d664..77d6bfb 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -677,6 +677,16 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, spiked_memmap.flush() del spiked_memmap + # get spiked data shape + spiked_shape = spiked.shape + + # get export data shape + d_shape = spiked_shape + (repeats,) + # TODO apr 9 2025 save dshape to json + #with open(shape_json, "w") as f: + #json.dump(shape, f, indent=4) + + if not fr_memmap_path.exists(): # init chance spiked memmap chance_spiked = init_memmap( path=spiked_memmap_path, From 4cdca8e2ef0a1a995319e850c3417da2d9e1d576 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 18:05:22 +0100 Subject: [PATCH 269/658] do not save to .h5 cuz memory issue --- pixels/pixels_utils.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 77d6bfb..6362194 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -734,17 +734,16 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, "dataframes and save.") # convert it to dataframe and save it - save_chance( - orig_idx=spiked.index, - orig_col=spiked.columns, - spiked_memmap_path=spiked_memmap_path, - fr_memmap_path=fr_memmap_path, - spiked_df_path=spiked_df_path, - fr_df_path=fr_df_path, - d_shape=d_shape, - ) - - print(f"\n> Chance data saved to {fr_df_path}.") + #save_chance( + # orig_idx=spiked.index, + # orig_col=spiked.columns, + # spiked_memmap_path=spiked_memmap_path, + # fr_memmap_path=fr_memmap_path, + # spiked_df_path=spiked_df_path, + # fr_df_path=fr_df_path, + # d_shape=d_shape, + #) + #print(f"\n> Chance data saved to {fr_df_path}.") return None From cfefa5fb9db0cbe6d098b8d6eeba501df5281ef9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 18:12:46 +0100 Subject: [PATCH 270/658] get spike chance and bin --- pixels/pixels_utils.py | 95 
+++++++++++++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 6362194..21d099d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -841,22 +841,87 @@ def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, return None -def get_spike_chance(spiked, sigma, sample_rate, spiked_memmap_path, - fr_memmap_path, fr_df_path, repeats=100): - if not fr_df_path.exists(): - # save spike chance data if does not exists - save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, - fr_df_path, spiked, sigma, sample_rate, repeats) - #else: - # d_shape = spiked.shape + (repeats,) +def get_spike_chance(sample_rate, positions, time_bin, pos_bin, + spiked_memmap_path, fr_memmap_path, memmap_shape_path, + idx_path, col_path): + if not fr_memmap_path.exists(): + raise PixelsError("\nHave you saved spike chance data yet?") + else: + # TODO apr 3 2025: we need to get positions here for binning!!! + # BUT HOW???? + _get_spike_chance(sample_rate, positions, time_bin, pos_bin, + spiked_memmap_path, fr_memmap_path, memmap_shape_path, + idx_path, col_path) - # spiked_chance = _get_spike_chance( - # path=spiked_memmap_path, - # shape=d_shape, - # dtype=np.int16, - # overwrite=False, - # readonly=True, - # ) + return None + + +def _get_spike_chance(sample_rate, positions, time_bin, pos_bin, + spiked_memmap_path, fr_memmap_path, memmap_shape_path, + idx_path, col_path): + + # TODO apr 9 2025: + # i do not need to save shape to file, all i need is unit count, repeat, + # so i load memmap without defining shape, then directly np.reshape(memmap, + # (-1, count, repeat))! + + with open(memmap_shape_path, "r") as f: + shape_data = json.load(f) + shape_list = shape_data.get("dshape", []) + d_shape = tuple(shape_list) + + spiked_chance = init_memmap( + path=spiked_memmap_path, + shape=d_shape, + dtype=np.int16, + overwrite=False, + readonly=True, + ) + + idx_df = read_hdf5(idx_path, key="multiindex") + idx = pd.MultiIndex.from_frame(idx_df) + trials = idx_df["trial"].unique() + col_df = read_hdf5(col_path, key="cols") + cols = pd.Index(col_df["unit"]) + + binned_shuffle = {} + temp = {} + # TODO apr 3 2025: implement multiprocessing here! 
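# A hedged sketch of the parallelisation the TODO above asks for: each repeat
# r is an independent slice of the chance memmap, so the per-repeat binning
# could go through a process pool much like _save_spike_chance already does.
# The worker below is an assumption layered on the module's own bin_vr_trial
# and reindex_by_longest, not an existing pipeline function.
import concurrent.futures

def _bin_one_repeat(shuffled, idx, cols, positions, trials, sample_rate,
                    time_bin, pos_bin):
    df = pd.DataFrame(shuffled, index=idx, columns=cols)
    per_trial = {}
    for t in trials:
        counts = df.xs(t, level="trial", axis=0)
        trial_pos = positions.loc[:, t].dropna()
        per_trial[t] = bin_vr_trial(counts, trial_pos, sample_rate, time_bin,
                                    pos_bin, bin_method="sum")
    return reindex_by_longest(dfs=per_trial, return_format="array")

def _bin_repeats_parallel(spiked_chance, idx, cols, positions, trials,
                          sample_rate, time_bin, pos_bin, workers=8):
    binned = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as pool:
        jobs = {
            pool.submit(_bin_one_repeat, np.asarray(spiked_chance[:, :, r]),
                        idx, cols, positions, trials, sample_rate, time_bin,
                        pos_bin): r
            for r in range(spiked_chance.shape[-1])
        }
        for job in concurrent.futures.as_completed(jobs):
            binned[jobs[job]] = job.result()
    return binned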
+ # get each repeat and create df + for r in range(d_shape[-1]): + shuffled = spiked_chance[:, :, r] + # create df + df = pd.DataFrame(shuffled, index=idx, columns=cols) + temp[r] = {} + for t in trials: + counts = df.xs(t, level="trial", axis=0) + trial_pos = positions.loc[:, t].dropna() + temp[r][t] = bin_vr_trial( + counts, + trial_pos, + sample_rate, + time_bin, + pos_bin, + bin_method="sum", + ) + binned_shuffle[r] = reindex_by_longest( + dfs=temp[r], + return_format="array", + ) + # concat + binned_shuffle_counts = np.stack( + list(binned_shuffle.values()), + axis=-1, + ) + shuffled_counts = { + "count": binned_shuffle_counts[:, :-2, ...], + "pos": binned_shuffle_counts[:, -2:, ...], + } + #count_path='/home/amz/running_data/npx/interim/20240812_az_VDCN09/20240812_az_VDCN09_imec0_light_all_spike_counts_shuffled_200ms_10cm.npz' + count_path='/home/amz/running_data/npx/interim/20240812_az_VDCN09/20240812_az_VDCN09_imec0_dark_all_spike_counts_shuffled_200ms_10cm.npz' + + np.savez_compressed(count_path, **shuffled_counts) + assert 0 # fr_chance = _get_spike_chance( # path=fr_memmap_path, From 826ece1ab52238d04b50685aecb873f72abe393d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 18:15:22 +0100 Subject: [PATCH 271/658] use bin_vr_trial in pixels_utils --- pixels/behaviours/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index db7f7be..c56d996 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3279,7 +3279,7 @@ def bin_aligned_trials( trial_pos.reset_index(inplace=True, drop=True) # get bin firing rates - bin_frs[i] = self.bin_vr_trial( + bin_frs[i] = xut.bin_vr_trial( data=rates, positions=trial_pos, time_bin=time_bin, @@ -3287,7 +3287,7 @@ def bin_aligned_trials( bin_method="mean", ) # get bin spike count - bin_counts[i] = self.bin_vr_trial( + bin_counts[i] = xut.bin_vr_trial( data=spiked, positions=trial_pos, time_bin=time_bin, From e3078890864b6becc1d55a70bcfc167b722a7616 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 10 Apr 2025 18:56:12 +0100 Subject: [PATCH 272/658] add spike sorting amp issue todo --- pixels/pixels_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 21d099d..e23132f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -350,6 +350,20 @@ def _sort_spikes(rec, output, ks_image_path, ks4_params): **ks4_params, ) + # TODO apr 10 2025: + # since this file is whitened, the amplitude of the signal is NOT the same + # as the original, and this might cause issue in calculating signal + # amplitude in spikeinterface. cuz in ks4 output, units amplitude is between + # 0-315, but in si it's between -4000 to 4000. + # POTENTIAL SOLUTIONS: + # 1. do what chris does, make another preprocessed recording just to build + # the sorting analyser, or + # 2. still use the temp_wh.dat from ks4, but check how ks4 handles amplitude + # and the unit of amplitude, correct it + # WHAT TO ACHIEVE: + # 1. without whitening, peak amplitude should be ~-70mV + # 2. 
with whitening, peak amplitude should be between -1 to 1 + # load ks preprocessed recording for # sorting analyser ks_preprocessed = se.read_binary( file_paths=output/"sorter_output/temp_wh.dat", From 564687236e583a86390adf8d2749ea69d9eb8bf0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 18:56:49 +0100 Subject: [PATCH 273/658] check key name length and load cache --- pixels/behaviours/base.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c56d996..6dbba61 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -104,23 +104,30 @@ def func(*args, **kwargs): except (KeyError, ValueError): # if key="df" is not found, then use HDFStore to list and read # all dfs + # create df as a dictionary to hold all dfs + df = {} with pd.HDFStore(output, "r") as store: # list all keys keys = store.keys() - # create df as a dictionary to hold all dfs - df = {} # TODO apr 2 2025: for now the nested dict have keys in the - # format of `/imec0.ap/positions`, this will not be the case + # format of `/imec0/positions`, this will not be the case # once i flatten files at the stream level rather than # session level, i.e., every pixels related cache will have # stream id in their name. for key in keys: - # read current df - data = store[key] - # remove "/" in key - key_name = key.lstrip("/") - # use key name as dict key - df[key_name] = data + # remove "/" in key and split + key_name = key.lstrip("/").split("/") + if len(key_name) == 1: + # use the only key name as dict key + df[key_name[0]] = store[key] + elif len(key_name) == 2: + # stream id is the first + stream = key_name[0] + # data name is the second + name = "/".join(key_name[1:]) + if stream not in df: + df[stream] = {} + df[stream][name] = store[key] print(f"> Cache loaded from {output}.") else: df = method(*args, **kwargs) From 5a0d97d1dd6e446afd57221010f8d954addead7c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 18:57:11 +0100 Subject: [PATCH 274/658] load brain surface depth yaml file --- pixels/behaviours/base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 6dbba61..868fcdf 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -41,6 +41,8 @@ from pixels.error import PixelsError from pixels.constants import * +from common_utils.file_utils import load_yaml + if TYPE_CHECKING: from typing import Optional, Literal @@ -707,7 +709,16 @@ def preprocess_raw(self): f"in total of {self.stream_count} stream(s)" ) - stream_files["preprocessed"] = xut.preprocess_raw(rec) + # load brain surface depths + surface_depths = load_yaml( + path=self.find_file(stream_files["surface_depth"]), + ) + + # preprocess + stream_files["preprocessed"] = xut.preprocess_raw( + rec, + surface_depths, + ) return None From 6fdccab02576bcc7c879a3e90a41c99409b59aad Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 18:58:04 +0100 Subject: [PATCH 275/658] move stream & output up to allow empty return --- pixels/behaviours/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 868fcdf..39cb5f2 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1788,6 +1788,8 @@ def _get_aligned_trials( """ action_labels = self.get_action_labels()[0] + streams = self.files["pixels"] + output = {} if units is 
None: units = self.select_units() @@ -1827,6 +1829,12 @@ def _get_aligned_trials( # only take ends from selected trials selected_ends = trials[np.where(np.isin(trials, ends))[0]] end_t = timestamps[selected_ends] + if selected_starts.size == 0: + print(f"> No trials found with label {label} and event {event}, " + "output will be empty.") + for key in streams.keys(): + output[keys[:-3]] = {} + return output # use original trial id as trial index trial_ids = vr_data.iloc[selected_starts].trial_count.unique() @@ -1838,12 +1846,10 @@ def _get_aligned_trials( scan_durations = scan_ends - scan_starts cursor = 0 # In sample points - output = {} rec_trials_fr = {} rec_trials_spiked = {} trial_positions = {} - streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): stream = stream_id[:-3] # allows multiple streams of recording, i.e., multiple probes From 8e36c7c7fcefe9ce464d5030327a9d8f9e6c000c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 18:58:34 +0100 Subject: [PATCH 276/658] make sure reindexing has correct col and idx name --- pixels/behaviours/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 39cb5f2..d077e19 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2011,14 +2011,13 @@ def _get_aligned_trials( elif data == "trial_rate": # TODO apr 2 2025: make sure this reindex_by_longest works - assert 0 fr = ioutils.reindex_by_longest( dfs=rec_trials_fr[stream], + level="trial", + idx_names=["trial", "time"], + col_names=["unit"], return_format="dataframe", - idx_names=["trial", "unit"], ) - fr = fr.reorder_levels(["unit", "trial"], axis=1) - fr = fr.sort_index(level=0, axis=1) output[stream] = { "fr": fr, From 3da55c9699ddd26a791c7423fcad189930fe8af3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 18:59:42 +0100 Subject: [PATCH 277/658] use cached align_trial data to bin trials --- pixels/behaviours/base.py | 266 ++++++++------------------------------ 1 file changed, 53 insertions(+), 213 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d077e19..39dd687 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3098,248 +3098,88 @@ def bin_aligned_trials( sigma=sigma, end_event=end_event, ) - #fr = self.align_trials( - # units=units, # NOTE: ALWAYS the first arg - # data="trial_rate", # NOTE: ALWAYS the second arg - # label=label, - # event=event, - # sigma=sigma, - # end_event=end_event, - #) + fr = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="trial_rate", # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = stream_id[:-3] + # define output path for binned spike rate + output_fr_path = self.interim/\ + f'cache/{self.name}_{stream}_{label}_{units}_{time_bin}_spike_rate.npz' + output_count_path = self.interim/\ + f'cache/{self.name}_{stream}_{label}_{units}_{time_bin}_spike_count.npz' + if output_count_path.exists(): + continue + key = f"{stream_id[:-3]}/" + print(f"\n> Binning trials from {stream_id}.") + # get stream spiked - stream_spiked = spiked[key + "spiked"] + stream_spiked = spiked[stream]["spiked"] # get stream positions - positions = spiked[key + "positions"] + positions = spiked[stream]["positions"] # get stream firing rates - #stream_fr = 
fr[stream]["fr"] + stream_fr = fr[stream]["fr"] - assert 0 - spiked_chance_path = self.processed / stream_files["spiked_shuffled"] - spiked_chance = ioutils.read_hdf5(spiked_chance_path, "spiked") + # TODO apr 11 2025: + # bin chance while bin data + #spiked_chance_path = self.processed / stream_files["spiked_shuffled"] + #spiked_chance = ioutils.read_hdf5(spiked_chance_path, "spiked") + #bin_counts_chance[stream_id] = {} - bin_frs[stream] = {} - bin_counts[stream] = {} - bin_counts_chance[stream] = {} + bin_frs[stream_id] = {} + bin_counts[stream_id] = {} for trial in positions.columns.unique(): - counts = stream_spiked.xs(trial, level="trial", axis=1) - #rates = stream_fr.xs(trial, level="trial", axis=1) - trial_pos = positions[trial] - assert 0 counts = stream_spiked.xs(trial, level="trial", axis=1).dropna() + rates = stream_fr.xs(trial, level="trial", axis=1).dropna() + trial_pos = positions[trial].dropna() # get bin spike count - bin_counts[stream_id][trial] = self.bin_vr_trial( + bin_counts[stream_id][trial] = xut.bin_vr_trial( data=counts, positions=trial_pos, + sample_rate=self.SAMPLE_RATE, time_bin=time_bin, pos_bin=pos_bin, bin_method="sum", ) # get bin firing rates - bin_frs[stream_id][trial] = self.bin_vr_trial( + bin_frs[stream_id][trial] = xut.bin_vr_trial( data=rates, positions=trial_pos, - time_bin=time_bin, - pos_bin=pos_bin, - bin_method="mean", - ) - assert 0 - - action_labels = self.get_action_labels()[0] - - # define output path for binned spike rate - output_fr_path = self.interim/\ - f'cache/{self.name}_{label}_{units}_{time_bin}_spike_rate.npz' - output_count_path = self.interim/\ - f'cache/{self.name}_{label}_{units}_{time_bin}_spike_count.npz' - - if output_count_path.exists() and output_fr_path.exists(): - print(f"> {self.name} {label} {units} {time_bin} .npz already " - "saved.") - return None - - print(f"> Binning data from {self.name} {label} {units} units to " - f"{time_bin}.") - - if units is None: - units = self.select_units() - - if not pos_bin is None: - behaviour_files = self.files["behaviour"] - # assume only one vr session for now - vr_dir = self.find_file(behaviour_files["vr_synched"][0]) - vr_data = ioutils.read_hdf5(vr_dir) - # get positions - positions = vr_data.position_in_tunnel - - #TODO: with multiple streams, spike times will be a list with multiple dfs, - #make sure old code does not break! 
- spikes = self.get_spike_times(use_si=True)[units] - # drop rows if all nans - spikes = spikes.dropna(how="all") - - # since each session has one behaviour session, now only one action - # label file - actions = action_labels["outcome"] - events = action_labels["events"] - # get timestamps index of behaviour in self.SAMPLE_RATE hz, to convert - # it to ms, do timestamps*1000/self.SAMPLE_RATE - timestamps = action_labels["timestamps"] - - # select frames of wanted trial type - trials = np.where(np.bitwise_and(actions, label))[0] - # map starts by event - starts = np.where(np.bitwise_and(events, event))[0] - # map starts by end event - ends = np.where(np.bitwise_and(events, end_event))[0] - - # only take starts from selected trials - selected_starts = trials[np.where(np.isin(trials, starts))[0]] - start_t = timestamps[selected_starts] - # only take ends from selected trials - selected_ends = trials[np.where(np.isin(trials, ends))[0]] - end_t = timestamps[selected_ends] - - # use original trial id as trial index - trial_ids = vr_data.iloc[selected_starts].trial_count.unique() - - # pad ends with 1 second extra to remove edge effects from convolution - scan_pad = self.SAMPLE_RATE - scan_starts = start_t - scan_pad - scan_ends = end_t + scan_pad - scan_durations = scan_ends - scan_starts - - cursor = 0 # In sample points - rec_trials = {} - trial_positions = {} - bin_frs = {} - bin_counts = {} - - streams = self.files["pixels"] - for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - # TODO jun 12 2024 skip other streams for now - if stream_num > 0: - continue - - # Account for multiple raw data files - meta = self.ap_meta[stream_num] - samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 - assert samples.is_integer() - in_SAMPLE_RATE_scale = (samples * self.SAMPLE_RATE)\ - / int(self.ap_meta[0]['imSampRate']) - cursor_duration = (cursor * self.SAMPLE_RATE)\ - / int(self.ap_meta[0]['imSampRate']) - rec_spikes = spikes[ - (cursor_duration <= spikes)\ - & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) - ] - cursor_duration - cursor += samples - - # Account for lag, in case the ephys recording was started before the - # behaviour - if not self._lag[stream_num] == None: - lag_start, _ = self._lag[stream_num] - else: - lag_start = timestamps[0] - - if lag_start < 0: - rec_spikes = rec_spikes + lag_start - - for i, start in enumerate(selected_starts): - # select spike times of current trial - trial_bool = (rec_spikes >= scan_starts[i])\ - & (rec_spikes <= scan_ends[i]) - trial = rec_spikes[trial_bool] - - # get position bin ids for current trial - trial_pos_bool = (positions.index >= start_t[i])\ - & (positions.index < end_t[i]) - trial_pos = positions[trial_pos_bool] - - # initiate binary spike times array for current trial - # NOTE: dtype must be float otherwise would get all 0 when - # passing gaussian kernel - times = np.zeros((scan_durations[i], len(units))).astype(float) - # use pixels time as spike index - idx = np.arange(scan_starts[i], scan_ends[i]) - # make it df, column name being unit id - spiked = pd.DataFrame(times, index=idx, columns=units) - - for unit in trial: - # get spike time for unit - u_times = trial[unit].values - # drop nan - u_times = u_times[~np.isnan(u_times)] - - # round spike times to use it as index - u_spike_idx = np.round(u_times).astype(int) - # make sure it does not exceed scan duration - if (u_spike_idx >= scan_ends[i]).any(): - beyonds = np.where(u_spike_idx >= scan_ends[i])[0] - u_spike_idx[beyonds] = idx[-1] - # make sure 
no double counted - u_spike_idx = np.unique(u_spike_idx) - - # set spiked to 1 - spiked.loc[u_spike_idx, unit] = 1 - - # convolve spike trains into spike rates - rates = signal.convolve_spike_trains( - times=spiked, - sigma=sigma, sample_rate=self.SAMPLE_RATE, - ) - # remove 1s padding from the start and end - rates = rates.iloc[scan_pad: -scan_pad] - spiked = spiked.iloc[scan_pad: -scan_pad] - # reset index to zero at the beginning of the trial - rates.reset_index(inplace=True, drop=True) - spiked.reset_index(inplace=True, drop=True) - trial_pos.reset_index(inplace=True, drop=True) - - # get bin firing rates - bin_frs[i] = xut.bin_vr_trial( - data=rates, - positions=trial_pos, time_bin=time_bin, pos_bin=pos_bin, bin_method="mean", ) - # get bin spike count - bin_counts[i] = xut.bin_vr_trial( - data=spiked, - positions=trial_pos, - time_bin=time_bin, - pos_bin=pos_bin, - bin_method="sum", - ) - # stack df values into np array - # reshape into trials x units x bins - assert 0 - # TODO apr 2 2025: make sure this reindex_by_longest works - bin_count_arr = ioutils.reindex_by_longest(bin_counts).T - bin_fr_arr = ioutils.reindex_by_longest(bin_frs).T - - # save bin_fr and bin_count, for alfredo & andrew - # use label as array key name - fr_to_save = { - "fr": bin_fr_arr[:, :-2, :], - "pos": bin_fr_arr[:, -2:, :], - } - np.savez_compressed(output_fr_path, **fr_to_save) - print(f"> Output saved at {output_fr_path}.") + # stack df values into np array + # reshape into trials x units x bins + bin_count_arr = ioutils.reindex_by_longest(bin_counts[stream_id]).T + bin_fr_arr = ioutils.reindex_by_longest(bin_frs[stream_id]).T - count_to_save = { - "count": bin_count_arr[:, :-2, :], - "pos": bin_count_arr[:, -2:, :], - } - np.savez_compressed(output_count_path, **count_to_save) - print(f"> Output saved at {output_count_path}.") + # save bin_fr and bin_count, for alfredo & andrew + # use label as array key name + fr_to_save = { + "fr": bin_fr_arr[:, :-2, :], + "pos": bin_fr_arr[:, -2:, :], + } + np.savez_compressed(output_fr_path, **fr_to_save) + print(f"> Output saved at {output_fr_path}.") + + count_to_save = { + "count": bin_count_arr[:, :-2, :], + "pos": bin_count_arr[:, -2:, :], + } + np.savez_compressed(output_count_path, **count_to_save) + print(f"> Output saved at {output_count_path}.") return None From 63b4cf30b0c2c7d7d73a42939309f6f2e843fbc1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 19:00:15 +0100 Subject: [PATCH 278/658] add brain surface depth yaml --- pixels/ioutils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index e06ddc6..8989658 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -49,7 +49,7 @@ def get_data_files(data_dir, session_name): "preprocessed": PosixPath("name.zarr"), "ap_downsampled": PosixPath("name.zarr"), "lfp_downsampled": PosixPath("name.zarr"), - "depth_info": PosixPath("name.json"), <== ?? 
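# A hedged sketch of the per-stream surface-depth yaml this hunk introduces in
# place of the old json entry; the exact layout and values are assumptions
# (preprocess_raw later indexes the loaded mapping by shank id), shown only to
# make the swap concrete.
import yaml

example = """
0: 2950   # shank id -> estimated brain-surface depth along the probe, in um
1: 3010
2: 2880
3: 2995
"""
surface_depths = yaml.safe_load(example)
print(surface_depths[0])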
+ "surface_depth": PosixPath("name.yaml"), "sorting_analyser": PosixPath("name.zarr"), }, "imecN":{ @@ -112,8 +112,8 @@ def get_data_files(data_dir, session_name): ) # depth info of probe - pixels[stream_id]["depth_info"] = base_name.with_name( - f"{session_name}_{stream_id}_depth_info.h5" + pixels[stream_id]["surface_depth"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_surface_depth.yaml" ) pixels[stream_id]["clustered_channels"] = base_name.with_name( f"{session_name}_{stream_id}_channel_clustering_results.h5" From 805bf982e035755937308ba81b85313a09a87b03 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 19:00:37 +0100 Subject: [PATCH 279/658] add column names --- pixels/ioutils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 8989658..5455483 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -746,7 +746,7 @@ def stream_video(video, length=None): if length == 0: break -def reindex_by_longest(dfs, idx_names=None, level=0, sort=True, +def reindex_by_longest(dfs, idx_names=None, col_names=None, level=0, sort=True, return_format="array"): """ params @@ -786,6 +786,8 @@ def reindex_by_longest(dfs, idx_names=None, level=0, sort=True, # set index name if idx_names: stacked_df.index.names = idx_names + if col_names: + stacked_df.columns.names = col_names elif isinstance(dfs, pd.DataFrame): stacked_df = dfs From 5eefc20ac27f4cb29af9d3f35f8f539fc395c0f6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 19:01:10 +0100 Subject: [PATCH 280/658] get shank id if only one shank used in multishank probe --- pixels/pixels_utils.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index e23132f..2ba2b5f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1003,3 +1003,36 @@ def bin_vr_trial(data, positions, sample_rate, time_bin, pos_bin, bin_data.reset_index(inplace=True, drop=True) return bin_data + +def get_shank_id_for_single_shank(rec): + # check probe type + ''' + npx 1.0: 0 + npx 2.0 alpha: 24 + npx 2.0 commercial: 2013 + ''' + probe_type = int(rec.get_annotation("probes_info")[0]["probe_type"]) + + # get channel x locations + ''' + shank 0: 0, 32 + shank 1: 250, 282 + shank 2: 500, 582 + shank 3: 750, 782 + ''' + if probe_type > 0: + x_locs = np.unique(rec.get_channel_locations()[:, 0]) + if np.all(x_locs < 200): + shank_id = 0 + elif np.all(x_locs > 200) and np.all(x_locs < 500): + shank_id = 1 + elif np.all(x_locs > 500) and np.all(x_locs < 700): + shank_id = 2 + elif np.all(x_locs > 700): + shank_id = 3 + + # get number of channels and set their group to shank id + ids = np.zeros(rec.channel_ids.shape).astype(int) + ids[:] = shank_id + + return ids From c7701a4e6296d396b6ad3188b831de51274119d5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 19:01:39 +0100 Subject: [PATCH 281/658] add arg brain surface depth --- pixels/pixels_utils.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 2ba2b5f..1f28289 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -53,7 +53,7 @@ def load_raw(paths, stream_id): return rec -def preprocess_raw(rec): +def preprocess_raw(rec, surface_depths): shank_groups = rec.get_channel_groups() if not np.all(shank_groups == shank_groups[0]): preprocessed = [] @@ -61,12 +61,24 @@ def preprocess_raw(rec): groups = rec.split_by("group") for g, 
group in enumerate(groups.values()): print(f"> Preprocessing shank {g}") - cleaned = _preprocess_raw(group) + # get brain surface depth of shank + surface_depth = surface_depths[g] + cleaned = _preprocess_raw(group, surface_depth) preprocessed.append(cleaned) # aggregate groups together preprocessed = si.aggregate_channels(preprocessed) else: - preprocessed = _preprocess_raw(rec) + # check which shank used + group_id = get_shank_id_for_single_shank(rec) + unique_id = np.unique(group_id)[0] + print("\n> Single shank used in multishank probe, change group id into " + f"{unique_id}.") + # change the group id + rec.set_channel_groups(group_id) + # get brain surface depth of shank + surface_depth = surface_depths[unique_id] + # preprocess + preprocessed = _preprocess_raw(rec, surface_depth) # NOTE jan 16 2025: # BUG: cannot set dtype back to int16, units from ks4 will have @@ -77,7 +89,7 @@ def preprocess_raw(rec): return preprocessed -def _preprocess_raw(rec): +def _preprocess_raw(rec, surface_depth): """ Implementation of preprocessing on raw pixels data. """ From 408c8228f3b455e90f0a1010f9c49ab27732d608 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 19:03:07 +0100 Subject: [PATCH 282/658] remove outside channels by brain surface depth --- pixels/pixels_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1f28289..95e78a1 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -104,10 +104,20 @@ def _preprocess_raw(rec, surface_depth): outside_channels_location="top", ) labels, counts = np.unique(chan_labels, return_counts=True) - for label, count in zip(labels, counts): print(f"\t\t> Found {count} channels labelled as {label}.") - rec_clean = rec_ps.remove_channels(bad_chan_ids) + rec_removed = rec_ps.remove_channels(bad_chan_ids) + + # get channel group id and use it to index into brain surface channel depth + shank_id = np.unique(rec_removed.get_channel_groups())[0] + # get channel depths + chan_depths = rec_removed.get_channel_locations()[:, 1] + # get channel ids + chan_ids = rec_removed.channel_ids + # remove channels outside by using identified brain surface depths + outside_chan_ids = chan_ids[chan_depths > surface_depth] + rec_clean = rec_removed.remove_channels(outside_chan_ids) + print(f"\t\t> Removed {outside_chan_ids.size} outside channels.") print("\t> step 3: do common median referencing.") # NOTE: dtype will be converted to float32 during motion correction From 8635a3112edc3753862259e272a8ba6f5a442a8d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 17 Apr 2025 19:03:55 +0100 Subject: [PATCH 283/658] add todo --- pixels/pixels_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 95e78a1..c46b255 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -567,6 +567,10 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): folder=curated_sa_dir, ) + # TODO apr 11 2025: + # after we curated units with quality metrics, there are still some noise in + # units, how do we use si to remove those ones identified qualitatively? 
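# A hedged sketch of one answer to the TODO above, mirroring what later
# patches in this series do: keep the qualitatively identified noisy unit ids
# in a small per-stream yaml and drop them from both the sorting and the
# analyser when loading. The yaml name and location are assumptions.
import spikeinterface.full as si
import yaml

def load_curated_analyser(sa_dir, noisy_units_yaml):
    sa = si.load_sorting_analyzer(sa_dir)
    with open(noisy_units_yaml) as fd:
        noisy_units = yaml.safe_load(fd) or []
    sorting = sa.sorting.remove_units(remove_unit_ids=noisy_units)
    cleaned = sa.remove_units(remove_unit_ids=noisy_units)
    # reattach so the sorting's properties survive the unit removal
    cleaned.sorting = sorting
    return cleaned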
+ return None From 6792217138ee780e9e7afa87bc8966c736f48299 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 18 Apr 2025 15:08:46 +0100 Subject: [PATCH 284/658] when use si to select units, pass selected_units as arg to select units before extracting times --- pixels/behaviours/base.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 39dd687..4154b49 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -983,7 +983,6 @@ def sort_spikes(self, mc_method="dredge"): ks_image_path=ks_image_path, ks4_params=ks4_params, ) - assert 0 return None @@ -1590,7 +1589,7 @@ def get_lfp_data(self): return self._get_processed_data("_lfp_data", "lfp_processed") - def _get_si_spike_times(self): + def _get_si_spike_times(self, units): """ get spike times in second with spikeinterface """ @@ -1600,7 +1599,11 @@ def _get_si_spike_times(self): for stream_num, (stream_id, stream_files) in enumerate(streams.items()): sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser - sa = si.load_sorting_analyzer(sa_dir) + temp_sa = si.load_sorting_analyzer(sa_dir) + # select units + sorting = temp_sa.sorting.select_units(units) + sa = temp_sa.select_units(units) + sa.sorting = sorting times = {} # get spike train @@ -1623,7 +1626,7 @@ def _get_si_spike_times(self): return spike_times[0] # NOTE: only deal with one stream for now - def get_spike_times(self, remapped=False, use_si=False): + def get_spike_times(self, units, remapped=False, use_si=False): """ Returns the sorted spike times. @@ -1636,7 +1639,7 @@ def get_spike_times(self, remapped=False, use_si=False): for stream_num, stream in enumerate(range(len(spike_times))): if use_si: - spike_times[stream_num] = self._get_si_spike_times() + spike_times[stream_num] = self._get_si_spike_times(units) else: if remapped and stream_num > 0: times = self.ks_outputs[stream_num] / f'spike_times_remapped.npy' @@ -1804,7 +1807,7 @@ def _get_aligned_trials( #TODO: with multiple streams, spike times will be a list with multiple dfs, #make sure old code does not break! 
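# A hedged sketch of the spike-time extraction behind the
# self.get_spike_times(units, use_si=True) call being updated just below: pull
# each unit's spike train out of the sorting and pad the ragged trains with
# NaN into one DataFrame, which is what the dropna/isnan handling downstream
# expects. Rescaling the times to the pipeline's own sample-rate clock is left
# out here.
import numpy as np
import pandas as pd

def spike_times_frame(sorting, unit_ids):
    trains = {
        unit: sorting.get_unit_spike_train(unit_id=unit, return_times=True)
        for unit in unit_ids
    }
    longest = max(len(t) for t in trains.values())
    padded = {
        unit: np.pad(t.astype(float), (0, longest - len(t)),
                     constant_values=np.nan)
        for unit, t in trains.items()
    }
    return pd.DataFrame(padded)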
- spikes = self.get_spike_times(use_si=True)[units] + spikes = self.get_spike_times(units, use_si=True) # drop rows if all nans spikes = spikes.dropna(how="all") From 403ddd246b9f35705963608d98cec0360a046b1f Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 18 Apr 2025 15:09:28 +0100 Subject: [PATCH 285/658] typo --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 4154b49..d4780a1 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1836,7 +1836,7 @@ def _get_aligned_trials( print(f"> No trials found with label {label} and event {event}, " "output will be empty.") for key in streams.keys(): - output[keys[:-3]] = {} + output[key[:-3]] = {} return output # use original trial id as trial index From 7051355de99a045d7ab2074051435b45a80c4c49 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 18 Apr 2025 15:09:38 +0100 Subject: [PATCH 286/658] remove noisy units when loading curated sorting analyser --- pixels/behaviours/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d4780a1..f5dc0bc 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2083,7 +2083,15 @@ def select_units( stream_files = self.files["pixels"]["imec0.ap"] sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser - sa = si.load_sorting_analyzer(sa_dir) + temp_sa = si.load_sorting_analyzer(sa_dir) + # remove noisy units + noisy_units = load_yaml( + path=self.find_file(stream_files["noisy_units"]), + ) + # remove units from sorting and reattach to sa to keep properties + sorting = temp_sa.sorting.remove_units(remove_unit_ids=noisy_units) + sa = temp_sa.remove_units(remove_unit_ids=noisy_units) + sa.sorting = sorting # get units unit_ids = sa.unit_ids From 912b5570b07bda066f3b0b16361e41dc4d3609f1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 18 Apr 2025 15:10:33 +0100 Subject: [PATCH 287/658] return if no units found --- pixels/behaviours/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f5dc0bc..62ca8b9 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -3134,6 +3134,10 @@ def bin_aligned_trials( # get stream spiked stream_spiked = spiked[stream]["spiked"] + if stream_spiked.size == 0: + print(f"\n> No units found in {units}, continue.") + return None + # get stream positions positions = spiked[stream]["positions"] # get stream firing rates From b950eb77e0fb04d6aaa367e3fb93e137b821d969 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 18 Apr 2025 15:10:53 +0100 Subject: [PATCH 288/658] add manually selected noisy units yaml --- pixels/ioutils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 5455483..65a50ac 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -149,6 +149,11 @@ def get_data_files(data_dir, session_name): f"{session_name}_{stream_id[:-3]}_fr_shuffled.h5" ) + # noise in curated units + pixels[stream_id]["noisy_units"] = base_name.with_name( + f"{session_name}_{stream_id[:-3]}_noisy_units.yaml" + ) + # old catgt data pixels[stream_id]["CatGT_ap_data"].append( str(base_name).replace("t0", "tcat") From 599c7e95b8571397036f8c1cb12f515e266d7387 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 18 Apr 2025 15:11:13 +0100 Subject: [PATCH 289/658] select units in sorting obj and reattach to sorting 
analyser to keep sorting properties --- pixels/pixels_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c46b255..921af21 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -515,7 +515,10 @@ def _curate_sorting(sorting, recording, output): # get unit ids curated_unit_ids = list(good_qms.index) # select curated + curated_sorting = sa.sorting.select_units(curated_unit_ids) curated_sa = sa.select_units(curated_unit_ids) + # reattach curated sorting to curated_sa to keep sorting properties + curated_sa.sorting = curated_sorting return sa, curated_sa From eb416eff9a22ce60cfd2e535d30f56a5bff20e47 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:32:41 +0100 Subject: [PATCH 290/658] add depth info for mouse --- pixels/ioutils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 65a50ac..fb73f41 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -162,6 +162,12 @@ def get_data_files(data_dir, session_name): str(base_name).replace("t0", "tcat") ) + # histology + mouse_id = session_name.split("_")[-1] + pixels[stream_id]["depth_info"] = base_name.with_name( + f"{mouse_id}_depth_info.yaml" + ) + #pixels[stream_id]["spike_rate_processed"] = base_name.with_name( # f"spike_rate_{stream_id}.h5" #) From 680e2b8640a9aad8a12b77d29c5081af5b64901b Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:32:55 +0100 Subject: [PATCH 291/658] add doc --- pixels/pixels_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 921af21..cd9fd5e 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1,3 +1,6 @@ +""" +This module provides utilities for pixels data. 
+""" import multiprocessing as mp import json From f70d973c39adc2af80f695219bc3b6d0b1743cf9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:33:22 +0100 Subject: [PATCH 292/658] define level names --- pixels/pixels_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index cd9fd5e..c7d68a7 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -186,6 +186,8 @@ def detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html """ shank_groups = rec.get_channel_groups() + level_names = ["shank", "spike_properties"] + if not np.all(shank_groups == shank_groups[0]): # split by groups groups = rec.split_by("group") @@ -198,10 +200,16 @@ def detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): dfs, axis=1, keys=groups.keys(), - names=["shank", "spike_properties"] + names=level_names, ) else: - df = self._detect_n_localise_peaks(rec, loc_method) + df = _detect_n_localise_peaks(rec, loc_method) + # add shank level on top + shank_id = shank_groups[0] + df.columns = pd.MultiIndex.from_tuples( + [(shank_id, col) for col in df.columns], + names=level_names, + ) return df From 46e27f0745aa0852d25ae076fd4bc40ef54e4975 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:38:44 +0100 Subject: [PATCH 293/658] typo --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c7d68a7..cb52223 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -276,7 +276,7 @@ def extract_band(rec, freq_min, freq_max, ftype="butter"): band = spre.bandpass_filter( rec, freq_min=freq_min, - freq_max=freq_min, + freq_max=freq_max, ftype=ftype, ) From 3e0b6834dd4bd113e7fddebf0f36a71672adb939 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:38:50 +0100 Subject: [PATCH 294/658] save report before export to phy --- pixels/pixels_utils.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index cb52223..4991bca 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -568,22 +568,19 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): output_folder=output/"curated_report", ) - # export to phy for additional manual curation if needed - sexp.export_to_phy( - sorting_analyzer=curated_sa, - output_folder=output/"phy", - copy_binary=False, - ) - # save sa to disk curated_sa.save_as( format="zarr", folder=curated_sa_dir, ) - # TODO apr 11 2025: - # after we curated units with quality metrics, there are still some noise in - # units, how do we use si to remove those ones identified qualitatively? 
+ # export to phy for additional manual curation if needed + sexp.export_to_phy( + sorting_analyzer=curated_sa, + output_folder=output/"curated_report/phy", + compute_pc_features=False, + copy_binary=False, + ) return None From 8d65adc1307ab934255399eda2439c01851d62ef Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:39:13 +0100 Subject: [PATCH 295/658] remove df dirs --- pixels/pixels_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4991bca..c2b0c45 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -700,9 +700,8 @@ def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, return None -def _save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, - fr_df_path, sigma, sample_rate, repeats, spiked, - spiked_shape, concat_spiked_path): +def _save_spike_chance(spiked_memmap_path, fr_memmap_path, sigma, sample_rate, + repeats, spiked, spiked_shape, concat_spiked_path): """ Implementation of saving chance level spike data. """ From 376082d7e811c2ba6fd3217816c37c3e4cffaf5c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:39:47 +0100 Subject: [PATCH 296/658] add histology directory --- pixels/experiment.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pixels/experiment.py b/pixels/experiment.py index 7cf1c87..03282e1 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -50,6 +50,7 @@ def __init__( data_dir, meta_dir=None, interim_dir=None, + hist_dir=None, processed_dir=None, session_date_fmt="%y%m%d", of_date=None, @@ -71,6 +72,9 @@ def __init__( else: self.meta_dir = None + if hist_dir: + self.hist_dir = Path(hist_dir).expanduser() + self.sessions = [] sessions = ioutils.get_sessions( mouse_ids, @@ -89,6 +93,7 @@ def __init__( metadata=[s['metadata'] for s in metadata], data_dir=metadata[0]['data_dir'], interim_dir=interim_dir, + hist_dir=hist_dir, processed_dir=processed_dir, ) ) From 143b257413b82981517cebc633e489f6d0921e7f Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:40:08 +0100 Subject: [PATCH 297/658] add mouse id --- pixels/behaviours/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 62ca8b9..2669454 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -203,8 +203,9 @@ class Behaviour(ABC): SAMPLE_RATE = 2000#1000 def __init__(self, name, data_dir, metadata=None, processed_dir=None, - interim_dir=None): + interim_dir=None, hist_dir=None): self.name = name + self.mouse_id = name.split("_")[-1] self.data_dir = data_dir self.metadata = metadata From 05f08ea7e94156dda6f09cb7c9dc88e3ad410055 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:40:17 +0100 Subject: [PATCH 298/658] add histology dir --- pixels/behaviours/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2669454..dc28381 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -221,6 +221,12 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, else: self.processed = Path(processed_dir).expanduser() / self.name + if hist_dir is None: + self.histology = self.data_dir / 'histology'\ + / processed / self.mouse_id + else: + self.histology = Path(hist_dir).expanduser() / self.mouse_id + self.files = ioutils.get_data_files(self.raw, name) self.CatGT_dir = sorted(glob.glob( From 
4bc97e22a2b150cf5594f14ca456b629871035cc Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:40:26 +0100 Subject: [PATCH 299/658] use depth info for mouse --- pixels/behaviours/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index dc28381..9bf6a1d 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -717,9 +717,10 @@ def preprocess_raw(self): ) # load brain surface depths - surface_depths = load_yaml( - path=self.find_file(stream_files["surface_depth"]), + depth_info = load_yaml( + path=self.histology / stream_files["depth_info"], ) + surface_depths = depth_info["raw_signal_depths"][stream_id] # preprocess stream_files["preprocessed"] = xut.preprocess_raw( From 34fcc2b836dc865b87ce4afef35472d18da0b8eb Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:40:37 +0100 Subject: [PATCH 300/658] explicitly return None --- pixels/behaviours/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 9bf6a1d..143d3c3 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -832,6 +832,9 @@ def extract_bands(self, freqs=None): ioutils.write_hdf5(output, data) """ + return None + + def run_catgt(self, CatGT_app=None, args=None) -> None: """ This func performs CatGT on copied AP data in the interim. From c5dafd8a61aff277e9c677574d4cf47c3222aeae Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 24 Apr 2025 15:40:57 +0100 Subject: [PATCH 301/658] remove doc --- pixels/behaviours/base.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 143d3c3..f668825 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1956,16 +1956,6 @@ def _get_aligned_trials( #if not rec_trials_fr[stream]: # return None - # TODO feb 28 2025: - # in this func, scanning period is longer than actual trials to avoid - # edging effect, so spiked is the whole scanning period then convolved - # into spike rate. if however, we get the spike_bool here, it does not - # have the scanning period buffer, the fr might have edging effect. but - # if we take the scanning period, the number of spikes will be - # different from the real data... 
- # SOLUTION: concat as below, shuffle per column, then convolve per - # column - # concat trial positions positions = ioutils.reindex_by_longest( dfs=trial_positions, From 9ece78b18dbcc6829b3c89079ed391abce5ce653 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 25 Apr 2025 12:06:48 +0100 Subject: [PATCH 302/658] correct group id if not all shanks used --- pixels/pixels_utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c2b0c45..43ad915 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -57,12 +57,18 @@ def load_raw(paths, stream_id): def preprocess_raw(rec, surface_depths): - shank_groups = rec.get_channel_groups() - if not np.all(shank_groups == shank_groups[0]): + if np.unique(rec.get_channel_groups()).size < 4: + # correct group id if not all shanks used + group_ids = correct_group_id(rec) + # change the group id + rec.set_channel_groups(group_ids) + + if not np.all(group_ids == group_ids[0]): + # if more than one shank used preprocessed = [] # split by groups groups = rec.split_by("group") - for g, group in enumerate(groups.values()): + for g, group in groups.items(): print(f"> Preprocessing shank {g}") # get brain surface depth of shank surface_depth = surface_depths[g] From adaa515353b054785ffa233aead71d04b27e17c6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 25 Apr 2025 12:07:37 +0100 Subject: [PATCH 303/658] use more informative name; use dict to map shank x locs --- pixels/pixels_utils.py | 49 ++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 43ad915..1e265a9 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1046,7 +1046,7 @@ def bin_vr_trial(data, positions, sample_rate, time_bin, pos_bin, return bin_data -def get_shank_id_for_single_shank(rec): +def correct_group_id(rec): # check probe type ''' npx 1.0: 0 @@ -1054,27 +1054,30 @@ def get_shank_id_for_single_shank(rec): npx 2.0 commercial: 2013 ''' probe_type = int(rec.get_annotation("probes_info")[0]["probe_type"]) + # double check it is multishank probe + assert probe_type > 0 # get channel x locations - ''' - shank 0: 0, 32 - shank 1: 250, 282 - shank 2: 500, 582 - shank 3: 750, 782 - ''' - if probe_type > 0: - x_locs = np.unique(rec.get_channel_locations()[:, 0]) - if np.all(x_locs < 200): - shank_id = 0 - elif np.all(x_locs > 200) and np.all(x_locs < 500): - shank_id = 1 - elif np.all(x_locs > 500) and np.all(x_locs < 700): - shank_id = 2 - elif np.all(x_locs > 700): - shank_id = 3 - - # get number of channels and set their group to shank id - ids = np.zeros(rec.channel_ids.shape).astype(int) - ids[:] = shank_id - - return ids + shank_x_locs = { + 0: [0, 32], + 1: [250, 282], + 2: [500, 582], + 3: [750, 782], + } + + # get group ids + group_ids = rec.get_channel_groups() + + x_locs = rec.get_channel_locations()[:, 0] + for shank_id, shank_x in shank_x_locs.items(): + # map bool channel x locations + shank_bool = np.isin(x_locs, shank_x) + if np.any(shank_bool) == False: + print(f"\n> Recording does not have shank {shank_id}, continue.") + continue + group_ids[shank_bool] = shank_id + + print("\n> Not all shanks used in multishank probe, change group ids into " + f"{np.unique(group_ids)}.") + + return group_ids From 3d5ccb01741154130350d184d0167710d3020fc9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 25 Apr 2025 12:17:02 +0100 Subject: [PATCH 304/658] make sure to use dict 
key as group id, not index --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1e265a9..457e5a7 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -198,7 +198,7 @@ def detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): # split by groups groups = rec.split_by("group") dfs = [] - for g, group in enumerate(groups.values()): + for g, group in groups.items(): print(f"\n> Estimate drift of shank {g}") dfs.append(_detect_n_localise_peaks(group, loc_method)) # concat shanks From 08cd9f86b478b8ffd9e1d57d05a3267d406fd682 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Apr 2025 13:22:16 +0100 Subject: [PATCH 305/658] use logging not print --- pixels/behaviours/base.py | 144 ++++++++++++++++++++------------------ 1 file changed, 74 insertions(+), 70 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f668825..cdc2aea 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -100,7 +100,7 @@ def func(*args, **kwargs): # load cache try: df = ioutils.read_hdf5(output) - print(f"> Cache loaded from {output}.") + logging.info(f"> Cache loaded from {output}.") except HDF5ExtError: df = None except (KeyError, ValueError): @@ -130,7 +130,7 @@ def func(*args, **kwargs): if stream not in df: df[stream] = {} df[stream][name] = store[key] - print(f"> Cache loaded from {output}.") + logging.info(f"> Cache loaded from {output}.") else: df = method(*args, **kwargs) output.parent.mkdir(parents=True, exist_ok=True) @@ -200,7 +200,7 @@ class Behaviour(ABC): """ - SAMPLE_RATE = 2000#1000 + SAMPLE_RATE = SAMPLE_RATE def __init__(self, name, data_dir, metadata=None, processed_dir=None, interim_dir=None, hist_dir=None): @@ -361,7 +361,7 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: raw = self.raw / name if raw.exists(): if copy: - print(f" {self.name}: Copying {name} to interim") + logging.info(f" {self.name}: Copying {name} to interim") copyfile(raw, interim) return interim return raw @@ -369,11 +369,11 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: tar = raw.with_name(raw.name + '.tar.gz') if tar.exists(): if copy: - print(f" {self.name}: Extracting {tar.name} to interim") + logging.info(f" {self.name}: Extracting {tar.name} to interim") with tarfile.open(tar) as open_tar: open_tar.extractall(path=self.interim) return interim - print(f" {self.name}: Extracting {tar.name}") + logging.info(f" {self.name}: Extracting {tar.name}") with tarfile.open(tar) as open_tar: open_tar.extractall(path=self.raw) return raw @@ -404,11 +404,11 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): # TODO jan 14 2025: # this func is not used in vr behaviour, since they are synched # in vd.session - print(" Finding lag between sync channels") + logging.info(" Finding lag between sync channels") recording = self.files[rec_num] if behavioural_data is None: - print(" Loading behavioural data") + logging.info(" Loading behavioural data") data_file = self.find_file(recording['behaviour']) behavioural_data = ioutils.read_tdms(data_file, groups=["NpxlSync_Signal"]) behavioural_data = signal.resample( @@ -416,7 +416,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): ) if sync_channel is None: - print(" Loading neuropixels sync channel") + logging.info(" Loading neuropixels sync channel") data_file = self.find_file(recording['lfp_data']) num_chans = 
self.lfp_meta[rec_num]['nSavedChans'] sync_channel = ioutils.read_bin(data_file, num_chans, channel=384) @@ -427,7 +427,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): behavioural_data = signal.binarise(behavioural_data) sync_channel = signal.binarise(sync_channel) - print(" Finding lag") + logging.info(" Finding lag") plot = self.processed / f'sync_{rec_num}.png' lag_start, match = signal.find_sync_lag( behavioural_data, sync_channel, plot=plot, @@ -437,8 +437,9 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): self._lag[rec_num] = (lag_start, lag_end) if match < 95: - print(" The sync channels did not match very well. Check the plot.") - print(f" Calculated lag: {(lag_start, lag_end)}") + logging.warning(" The sync channels did not match very well. " + "Check the plot.") + logging.info(f" Calculated lag: {(lag_start, lag_end)}") lag_json = [] for lag in self._lag: @@ -479,8 +480,8 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): # do not redo the remapping if not necessary if output.exists(): - print(f'\n> Spike times from {self.ks_outputs[remap_stream_idx]}\ - already remapped, next session.') + logging.info("\n> Spike times from " + f"{self.ks_outputs[remap_stream_idx]} already remapped, next session.") cluster_times = self.get_spike_times()[remap_stream_idx] remapped_cluster_times = self.get_spike_times( remapped=True)[remap_stream_idx] @@ -544,7 +545,7 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): times_ms = times * self.SAMPLE_RATE / orig_rate lag = [None, 'later', 'earlier'] - print(f"""\n> {stream_ids[0]} started\r + logging.info(f"""\n> {stream_ids[0]} started\r {abs(initial_dt[0]*1000):.2f}ms {lag[int(np.sign(initial_dt[0]))]}\r and finished\r {abs(initial_dt[-1]*1000):.2f}ms {lag[-int(np.sign(initial_dt[0]))]}\r @@ -565,7 +566,7 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): remapped_times_ms[t] = ((time - edges[remap_stream_idx][bin_idx]) * scales[bin_idx]) + edges[0][bin_idx] - print(f"""\n> Remap stats {stream_ids[remap_stream_idx]} spike times:\r + logging.info(f"""\n> Remap stats {stream_ids[remap_stream_idx]} spike times:\r median shift {np.median(remapped_times_ms-times_ms):.2f}ms,\r min shift {np.min(remapped_times_ms-times_ms):.2f}ms,\r max shift {np.max(remapped_times_ms-times_ms):.2f}ms.""") @@ -573,7 +574,7 @@ def sync_streams(self, SYNC_BIN, remap_stream_idx): # convert remappmed times back to its original sample index remapped_times = np.uint64(remapped_times_ms * orig_rate / self.SAMPLE_RATE) np.save(output, remapped_times) - print(f'\n> Spike times remapping output saved to\n {output}.') + logging.info(f'\n> Spike times remapping output saved to\n {output}.') # load remapped spike times of each cluster cluster_times = self.get_spike_times()[remap_stream_idx] @@ -595,11 +596,11 @@ def process_behaviour(self): # NOTE jan 14 2025: # this func is not used by vr behaviour for rec_num, recording in enumerate(self.files): - print( + logging.info( f">>>>> Processing behaviour for recording {rec_num + 1} of {len(self.files)}" ) - print(f"> Loading behavioural data") + logging.info(f"> Loading behavioural data") behavioural_data = ioutils.read_tdms(self.find_file(recording['behaviour'])) # ignore any columns that have Nans; these just contain settings @@ -607,12 +608,12 @@ def process_behaviour(self): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - print(f"> Downsampling to {self.SAMPLE_RATE} Hz") + logging.info(f"> Downsampling to 
{self.SAMPLE_RATE} Hz") behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] - print(f"> Syncing to Neuropixels data") + logging.info(f"> Syncing to Neuropixels data") if self._lag[rec_num] is None: self.sync_data( rec_num, @@ -622,20 +623,20 @@ def process_behaviour(self): behavioural_data = behavioural_data[max(lag_start, 0):-1-max(lag_end, 0)] behavioural_data.index = range(len(behavioural_data)) - print(f"> Extracting action labels") + logging.info(f"> Extracting action labels") self._action_labels[rec_num] = self._extract_action_labels(rec_num, behavioural_data) output = self.processed / recording['action_labels'] np.save(output, self._action_labels[rec_num]) - print(f"> Saved to: {output}") + logging.info(f"> Saved to: {output}") output = self.processed / recording['behaviour_processed'] - print(f"> Saving downsampled behavioural data to:") - print(f" {output}") + logging.info(f"> Saving downsampled behavioural data to:") + logging.info(f" {output}") behavioural_data.drop("/'NpxlSync_Signal'/'0'", axis=1, inplace=True) ioutils.write_hdf5(output, behavioural_data) self._behavioural_data[rec_num] = behavioural_data - print("> Done!") + logging.info("> Done!") def correct_motion(self, mc_method="dredge"): """ @@ -653,7 +654,7 @@ def correct_motion(self, mc_method="dredge"): None """ if mc_method == "ks": - print(f"> Correct motion later with {mc_method}.") + logging.info(f"> Correct motion later with {mc_method}.") return None # get pixels streams @@ -662,7 +663,7 @@ def correct_motion(self, mc_method="dredge"): for stream_id, stream_files in streams.items(): output = self.interim / stream_files["motion_corrected"] if output.exists(): - print(f"> Motion corrected {stream_id} loaded.") + logging.info(f"> Motion corrected {stream_id} loaded.") continue # preprocess raw recording @@ -671,7 +672,7 @@ def correct_motion(self, mc_method="dredge"): # load preprocessed rec rec = stream_files["preprocessed"] - print( + logging.info( f"\n>>>>> Correcting motion for recording from {stream_id} " f"in total of {self.stream_count} stream(s) with {mc_method}" ) @@ -711,7 +712,7 @@ def preprocess_raw(self): for stream_id, stream_files in streams.items(): # load raw si rec rec = stream_files["si_rec"] - print( + logging.info( f"\n>>>>> Preprocessing data for recording from {stream_id} " f"in total of {self.stream_count} stream(s)" ) @@ -757,7 +758,7 @@ def detect_n_localise_peaks(self, loc_method="monopolar_triangulation"): for stream_id, stream_files in streams.items(): output = self.processed / stream_files["detected_peaks"] if output.exists(): - print(f"> Peaks from {stream_id} already detected.") + logging.info(f"> Peaks from {stream_id} already detected.") continue # get ap band @@ -790,13 +791,13 @@ def extract_bands(self, freqs=None): for name, freqs in bands.items(): output = self.processed / stream_files[f"{name}_extracted"] if output.exists(): - print(f"> {name} bands from {stream_id} loaded.") + logging.info(f"> {name} bands from {stream_id} loaded.") continue # preprocess raw data self.preprocess_raw() - print( + logging.info( f">>>>> Extracting {name} bands from {self.name} " f"{stream_id} in total of {self.stream_count} stream(s)" ) @@ -823,7 +824,7 @@ def extract_bands(self, freqs=None): self.sync_data(rec_num, sync_channel=data[:, -1]) lag_start, lag_end = self._lag[rec_num] - print(f"> Saving data to {output}") + logging.info(f"> Saving data to 
{output}") if lag_end < 0: data = data[:lag_end] if lag_start < 0: @@ -866,11 +867,12 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: if (isinstance(self.CatGT_dir, list) and len(self.CatGT_dir) != 0 and len(os.listdir(self.CatGT_dir[0])) != 0): - print(f"\nCatGT already performed on ap data of {self.name}. Next session.\n") + logging.info(f"\nCatGT already performed on ap data of {self.name}." + " Next session.\n") return else: #TODO: finish this here so that catgt can run together with sorting - print(f"\n> Running CatGT on ap data of {self.name}") + logging.info(f"\n> Running CatGT on ap data of {self.name}") #_dir = self.interim if args == None: @@ -896,7 +898,7 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: -xid=2,0,384,6,20,15" session_args = f"-dir={self.interim} -run={self.name} -dest={self.interim} " + args - print(f"\ncatgt args of {self.name}: \n{session_args}") + logging.info(f"\ncatgt args of {self.name}: \n{session_args}") subprocess.run( ['./run_catgt.sh', session_args]) @@ -960,9 +962,9 @@ def sort_spikes(self, mc_method="dredge"): # check if already sorted and exported sa_dir = self.processed / stream_files["sorting_analyser"] if not sa_dir.exists(): - print(f"> {self.name} {stream_id} not sorted/exported.\n") + logging.info(f"> {self.name} {stream_id} not sorted/exported.\n") else: - print("> Already sorted and exported, next session.\n") + logging.info("> Already sorted and exported, next session.\n") continue # get catgt directory @@ -1038,7 +1040,7 @@ def configure_motion_tracking(self, project: str) -> None: videos.append(self.interim / video.with_suffix('.avi')) if not videos: - print(self.name, ": No matching videos for project:", project) + logging.info(self.name, ": No matching videos for project:", project) return if config: @@ -1048,7 +1050,7 @@ def configure_motion_tracking(self, project: str) -> None: copy_videos=False, ) else: - print(f"Config not found.") + logging.warning(f"Config not found.") reply = input("Create new project? [Y/n]") if reply and reply[0].lower() == "n": raise PixelsError("A DLC project is needed for motion tracking.") @@ -1654,7 +1656,7 @@ def get_spike_times(self, units, remapped=False, use_si=False): else: if remapped and stream_num > 0: times = self.ks_outputs[stream_num] / f'spike_times_remapped.npy' - print(f"""\n> Found remapped spike times from\r + logging.info(f"""\n> Found remapped spike times from\r {self.ks_outputs[stream_num]}, try to load this.""") else: times = self.ks_outputs[stream_num] / f'spike_times.npy' @@ -1680,7 +1682,8 @@ def get_spike_times(self, units, remapped=False, use_si=False): ) repeats = c_times[np.where(counts>1)] if len(repeats>1): - print(f"> removed {len(repeats)} double-counted spikes from cluster {c}.") + logging.info(f"> removed {len(repeats)} double-counted " + "spikes from cluster {c}.") by_clust[c] = pd.Series(uniques) spike_times[stream_num] = pd.concat(by_clust, axis=1, names=['unit']) @@ -1745,7 +1748,8 @@ def _get_aligned_spike_times( if len(centre) == 0: # See comment in align_trials as to why we just continue instead of # erroring like we used to here. - print("No event found for an action. If this is OK, ignore this.") + logging.info("No event found for an action. 
If this is OK, " + "ignore this.") continue centre = start + centre[0] @@ -1844,9 +1848,10 @@ def _get_aligned_trials( selected_ends = trials[np.where(np.isin(trials, ends))[0]] end_t = timestamps[selected_ends] if selected_starts.size == 0: - print(f"> No trials found with label {label} and event {event}, " + logging.info(f"> No trials found with label {label} and event {event}, " "output will be empty.") for key in streams.keys(): + assert 0 output[key[:-3]] = {} return output @@ -2297,7 +2302,7 @@ def align_trials( raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") if data in ("spike_times", "spike_rate"): - print(f"Aligning {data} to trials.") + logging.info(f"Aligning {data} to trials.") # we let a dedicated function handle aligning spike times return self._get_aligned_spike_times( label, event, duration, rate=data == "spike_rate", sigma=sigma, @@ -2305,7 +2310,7 @@ def align_trials( ) if "trial" in data: - print(f"Aligning {data} of {units} units to trials.") + logging.info(f"Aligning {data} of {units} units to trials.") return self._get_aligned_trials( label, event, data=data, units=units, sigma=sigma, end_event=end_event, @@ -2317,14 +2322,14 @@ def align_trials( action_labels = self.get_action_labels() if raw: - print(f"Aligning raw {data} data to trials.") + logging.info(f"Aligning raw {data} data to trials.") getter = getattr(self, f"get_{data}_data_raw", None) if not getter: raise PixelsError(f"align_trials: {data} doesn't have a 'raw' option.") values, SAMPLE_RATE = getter() else: - print(f"Aligning {data} data to trials.") + logging.info(f"Aligning {data} data to trials.") if dlc_project: values = self.get_motion_tracking_data(dlc_project) elif data == "motion_index": @@ -2370,7 +2375,7 @@ def align_trials( # here to warn the user in case it is an error, while otherwise # continuing. #raise PixelsError('Action labels probably miscalculated') - print("No event found for an action. If this is OK, ignore this.") + logging.info("No event found for an action. If this is OK, ignore this.") continue centre = start + centre[0] centre = int(centre * SAMPLE_RATE / self.SAMPLE_RATE) @@ -2465,7 +2470,7 @@ def align_clips(self, label, event, video_match, duration=1): for start in trial_starts: centre = np.where(np.bitwise_and(events[start:start + scan_duration], event))[0] if len(centre) == 0: - print("No event found for an action. If this is OK, ignore this.") + logging.info("No event found for an action. 
If this is OK, ignore this.") continue centre = start + centre[0] frames = timings.loc[ @@ -2509,7 +2514,7 @@ def get_good_units_info(self): if self._good_unit_info is None: #az: good_units_info.tsv saved while running depth_profile.py info_file = self.interim / 'good_units_info.tsv' - #print(f"> got good unit info at {info_file}\n") + #logging.info(f"> got good unit info at {info_file}\n") try: info = pd.read_csv(info_file, sep='\t') @@ -2527,7 +2532,7 @@ def get_spike_widths(self, units=None): all_widths = self.get_spike_widths() return all_widths.loc[all_widths.unit.isin(units)] - print("Calculating spike widths") + logging.info("Calculating spike widths") waveforms = self.get_spike_waveforms() widths = [] @@ -2580,7 +2585,7 @@ def get_spike_waveforms(self, units=None, method='phy'): rec_forms = {} for u, unit in enumerate(units): - print(f"{round(100 * u / len(units), 2)}% complete") + logging.info(f"{round(100 * u / len(units), 2)}% complete") # get the waveforms from only the best channel spike_ids = model.get_cluster_spikes(unit) best_chan = model.get_cluster_channels(unit)[0] @@ -2628,13 +2633,13 @@ def get_spike_waveforms(self, units=None, method='phy'): ks_mod_time = os.path.getmtime(self.ks_outputs / 'cluster_info.tsv') assert template_cache_mod_time < ks_mod_time check = True # re-extract waveforms - print("> Re-extracting waveforms since kilosort output is newer.") + logging.info("> Re-extracting waveforms since kilosort output is newer.") except: if 'template_cache_mod_time' in locals(): - print("> Loading existing waveforms.") + logging.info("> Loading existing waveforms.") check = False # load existing waveforms else: - print("> Extracting waveforms since they are not extracted.") + logging.info("> Extracting waveforms since they are not extracted.") check = True # re-extract waveforms """ @@ -2647,7 +2652,7 @@ def get_spike_waveforms(self, units=None, method='phy'): test.annotate(is_filtered=True) # check all annotations test.get_annotation('is_filtered') - print(test) + logging.info(test) """ # extract waveforms @@ -2697,7 +2702,7 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): # normalise these metrics before passing to k-means columns = ["unit", "duration", "trough_peak_ratio", "half_width", "repolarisation_slope", "recovery_slope"] - print(f"> Calculating waveform metrics {columns[1:]}...\n") + logging.info(f"> Calculating waveform metrics {columns[1:]}...\n") waveforms = self.get_spike_waveforms() # remove nan values @@ -2747,7 +2752,7 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): # repolarisation slope returns = np.where(mean_waveform.iloc[trough_idx:] >= 0) + trough_idx if len(returns[0]) == 0: - print(f"> The mean waveformrns never returned to baseline?\n") + logging.info(f"> The mean waveformrns never returned to baseline?\n") return_idx = mean_waveform.shape[0] - 1 else: return_idx = returns[0][0] @@ -2887,7 +2892,7 @@ def get_aligned_spike_rate_CI( trial_responses = [] for trial, t_start, t_end in zip(trials, start, end): if not (t_start < t_end): - print( + logging.warning( f"Warning: trial {trial} skipped in CI calculation due to bad timepoints" ) continue @@ -3130,13 +3135,12 @@ def bin_aligned_trials( if output_count_path.exists(): continue - key = f"{stream_id[:-3]}/" - print(f"\n> Binning trials from {stream_id}.") + logging.info(f"\n> Binning trials from {stream_id}.") # get stream spiked stream_spiked = spiked[stream]["spiked"] if stream_spiked.size == 0: - print(f"\n> No units found in 
{units}, continue.") + logging.info(f"\n> No units found in {units}, continue.") return None # get stream positions @@ -3188,14 +3192,14 @@ def bin_aligned_trials( "pos": bin_fr_arr[:, -2:, :], } np.savez_compressed(output_fr_path, **fr_to_save) - print(f"> Output saved at {output_fr_path}.") + logging.info(f"> Output saved at {output_fr_path}.") count_to_save = { "count": bin_count_arr[:, :-2, :], "pos": bin_count_arr[:, -2:, :], } np.savez_compressed(output_count_path, **count_to_save) - print(f"> Output saved at {output_count_path}.") + logging.info(f"> Output saved at {output_count_path}.") return None @@ -3219,11 +3223,11 @@ def get_chance_data(self, time_bin, pos_bin): # TODO apr 3 2025: how the fuck to get positions here???? # TEMP: get it manually... # light - pos_path = self.interim /\ - "cache/align_trials_all_trial_times_725_1_100_512.h5" - # dark #pos_path = self.interim /\ - # "cache/align_trials_all_trial_times_1322_1_100_512.h5" + # "cache/align_trials_all_trial_times_725_1_100_512.h5" + # dark + pos_path = self.interim /\ + "cache/align_trials_all_trial_times_1322_1_100_512.h5" with pd.HDFStore(pos_path, "r") as store: # list all keys From 757cbac0e3f50b2d7633fda4b9afa81a457086c4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Apr 2025 13:23:49 +0100 Subject: [PATCH 306/658] add configs; use logging --- pixels/behaviours/virtual_reality.py | 9 +++++---- pixels/configs.py | 30 ++++++++++++++++++++++++++++ pixels/pixels_utils.py | 1 + 3 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 pixels/configs.py diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 7bf3fae..99ebcbc 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -21,6 +21,7 @@ import pixels.signal_utils as signal from pixels import ioutils from pixels.behaviours import Behaviour +from pixels.configs import * from common_utils import file_utils @@ -144,7 +145,7 @@ def _extract_action_labels(self, vr, vr_data): trial_dark = (vr_data.trial_type == Conditions.DARK) # <<<< definitions <<<< - print(">> Mapping vr event times...") + logging.info(">> Mapping vr event times...") # >>>> gray >>>> # get timestamps of gray @@ -262,7 +263,7 @@ def _extract_action_labels(self, vr, vr_data): # TODO jun 27 2024 positional events and valve events needs mapping - print(">> Mapping vr action times...") + logging.info(">> Mapping vr action times...") # >>>> map reward types >>>> # get non-zero reward types @@ -298,8 +299,8 @@ def _extract_action_labels(self, vr, vr_data): assert (trial == vr_data.trial_count.unique().max()) assert (vr_data[of_trial].position_in_tunnel.max()\ < vr.tunnel_reset) - print(f"> trial {trial} is unfinished when session ends, so " - "there is no outcome.") + logging.info(f"> trial {trial} is unfinished when session ends, " + "so there is no outcome.") # <<<< unfinished trial <<<< else: # >>>> non punished >>>> diff --git a/pixels/configs.py b/pixels/configs.py new file mode 100644 index 0000000..6f00399 --- /dev/null +++ b/pixels/configs.py @@ -0,0 +1,30 @@ +import logging + +from wavpack_numcodecs import WavPack +import spikeinterface as si + +# Configure logging to include a timestamp with seconds +logging.basicConfig( + level=logging.INFO, + format='''\n%(asctime)s %(levelname)s: %(message)s\ + \n[in %(filename)s:%(lineno)d]''', + datefmt='%Y%m%d %H:%M:%S', +) + +#logging.info('This is an info message.') +#logging.warning('This is a warning message.') +#logging.error('This is an error message.') + 
+# set si job_kwargs +job_kwargs = dict( + n_jobs=0.8, # 80% core + chunk_duration='1s', + progress_bar=True, +) +si.set_global_job_kwargs(**job_kwargs) + +# instantiate WavPack compressor +wv_compressor = WavPack( + level=3, # high compression + bps=None, # lossless +) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 457e5a7..839bdca 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -19,6 +19,7 @@ import pixels.signal_utils as signal from pixels.ioutils import write_hdf5, reindex_by_longest from pixels.error import PixelsError +from pixels.configs import * from common_utils.math_utils import random_sampling from common_utils.file_utils import init_memmap, read_hdf5 From 2b104c71e25ced697042c58676a09b9445f91eb6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Apr 2025 13:24:09 +0100 Subject: [PATCH 307/658] define behaviour rate --- pixels/constants.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pixels/constants.py b/pixels/constants.py index f5127b6..8f52d89 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -1,7 +1,15 @@ """ This file contains some constants parameters for the pixels pipeline. """ +import numpy as np + +SAMPLE_RATE = 2000 + freq_bands = { "ap":[300, 9000], "lfp":[0.5, 300], } + +BEHAVIOUR_HZ = 25000 + +np.random.seed(BEHAVIOUR_HZ) From 113a649f628ecb0a6edd2820a85ba0ccf18d4203 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Apr 2025 14:15:29 +0100 Subject: [PATCH 308/658] add underscore --- pixels/ioutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index fb73f41..c93c315 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -194,7 +194,7 @@ def get_data_files(data_dir, session_name): session_name + "_pupil_processed.h5" )) behaviour["motion_index"] = base_name.with_name( - session_name + "motion_index.npz" + session_name + "_motion_index.npz" ) behaviour["motion_tracking"] = base_name.with_name( session_name + "_motion_tracking.h5" From c91584df7ea383e7c433378c56c16ec5e7e93664 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Apr 2025 14:19:43 +0100 Subject: [PATCH 309/658] correct group id if not all shanks r used; just get the unique id here --- pixels/pixels_utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 839bdca..299f052 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -78,13 +78,8 @@ def preprocess_raw(rec, surface_depths): # aggregate groups together preprocessed = si.aggregate_channels(preprocessed) else: - # check which shank used - group_id = get_shank_id_for_single_shank(rec) - unique_id = np.unique(group_id)[0] - print("\n> Single shank used in multishank probe, change group id into " - f"{unique_id}.") - # change the group id - rec.set_channel_groups(group_id) + # if only one shank used, check which shank + unique_id = np.unique(group_ids)[0] # get brain surface depth of shank surface_depth = surface_depths[unique_id] # preprocess From 1d598d4a32aac542e3ff32fb8eb2a1f75ac49497 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Apr 2025 14:20:10 +0100 Subject: [PATCH 310/658] add stream class --- pixels/stream.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 pixels/stream.py diff --git a/pixels/stream.py b/pixels/stream.py new file mode 100644 index 0000000..df570a8 --- /dev/null +++ b/pixels/stream.py @@ -0,0 +1,157 @@ +import numpy as np + 
+from pixels import ioutils +from pixels import pixels_utils as xut +import pixels.signal_utils as signal +from pixels.configs import * + +from common_utils import file_utils + +class Stream: + def __init__( + self, + stream_id, + stream_num, + files, + session, + ): + self.stream_id = stream_id + self.stream_num = stream_num + self.probe_id = stream_id[:-3] + self.files = files + self.session = session + self.behaviour_files = session.files["behaviour"] + self.SAMPLE_RATE = session.SAMPLE_RATE + + def __repr__(self): + return f"" + + + def sync_vr(self, vr): + # get action labels + action_labels = self.session.get_action_labels()[self.stream_num] + if action_labels: + logging.info(f"\n> {self.stream_id} from {self.session.name} is " + "already synched with vr, continue.") + else: + _sync_vr(self, vr) + + return None + + + def _sync_vr(self, vr): + # get spike data + spike_data = self.session.find_file( + name=self.files["ap_raw"][self.stream_num], + copy=True, + ) + + # get vr data + vr_session = vr.sessions[0] + synched_vr_path = vr_session.cache_dir + "synched/" +\ + vr_session.name + "_vr_synched.h5" + + try: + synched_vr = file_utils.read_hdf5(synched_vr_path) + logging.info("> synchronised vr loaded") + except: + # get sync pulses + sync_map = ioutils.read_bin(spike_data, 385, 384) + syncs = signal.binarise(sync_map) + + # >>>> resample pixels sync pulse to sample rate >>>> + # get ap data sampling rate + spike_samp_rate = int(self.session.ap_meta[0]['imSampRate']) + # downsample pixels sync pulse + downsampled = signal.decimate( + array=syncs, + from_hz=spike_samp_rate, + to_hz=self.SAMPLE_RATE, + ) + # binarise to avoid non integers + pixels_syncs = signal.binarise(downsampled) + # <<<< resample pixels sync pulse to 1kHz <<<< + + # TODO apr 11 2025: + # for 20250723 VDCN09, sync pulses are weird. check number of syncs + # from 145s onwards, see if it matches with vr frames from 100s + # onwards + + # get the rise and fall edges in pixels sync + pixels_edges = np.where(np.diff(pixels_syncs) != 0)[0] + 1 + # double check if the pulse from arduino initiation is also + # included, if so there will be two long pulses before vr frames + first_pulses = np.diff(pixels_edges)[:4] + if (first_pulses > 1000).all(): + logging.info("\n> There are two long pulses before vr frames, " + "remove both.") + remove = 4 + else: + remove = 2 + pixels_vr_edges = pixels_edges[remove:] + # convert value into their index to calculate all timestamps + pixels_idx = np.arange(pixels_syncs.shape[0]) + + synched_vr = vr.sync_streams( + self.SAMPLE_RATE, + pixels_vr_edges, + pixels_idx, + )[vr_session.name] + + # save to pixels processed dir + file_utils.write_hdf5( + self.session.processed /\ + self.behaviour_files['vr_synched'][self.stream_num], + synched_vr, + ) + + # get action label dir + action_labels_path = self.session.processed /\ + self.behaviour_files["action_labels"][self.stream_num] + + # extract and save action labels + action_labels = self.session._extract_action_labels( + vr_session, + synched_vr, + ) + np.savez_compressed( + action_labels_path, + outcome=action_labels[:, 0], + events=action_labels[:, 1], + timestamps=action_labels[:, 2], + ) + logging.info(f"> Action labels saved to: {action_labels_path}.") + + return None + + + def save_spike_chance(self, spiked, sigma): + # TODO apr 21 2025: + # do we put this func here or in stream.py?? 
+ # save index and columns to reconstruct df for shuffled data + assert 0 + ioutils.save_index_to_frame( + df=spiked, + path=shuffled_idx_path, + ) + ioutils.save_cols_to_frame( + df=spiked, + path=shuffled_col_path, + ) + + # get chance data paths + paths = { + "spiked_memmap_path": self.interim /\ + stream_files["spiked_shuffled_memmap"], + "fr_memmap_path": self.interim / stream_files["fr_shuffled_memmap"], + } + + # save chance data + xut.save_spike_chance( + **paths, + sigma=sigma, + sample_rate=self.SAMPLE_RATE, + spiked=spiked, + ) + + return None From 29a712bd41850ccb9d912e1198b1dccf01106bf8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:26:02 +0100 Subject: [PATCH 311/658] move configs to configs.py; move SelectedUnits class to units.py; move cacheable to decorators.py --- pixels/behaviours/base.py | 136 +------------------------------------- 1 file changed, 3 insertions(+), 133 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index cdc2aea..a0fa378 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -33,151 +33,21 @@ import spikeinterface.qualitymetrics as sqm from scipy import interpolate from tables import HDF5ExtError -from wavpack_numcodecs import WavPack from pixels import ioutils import pixels.signal_utils as signal import pixels.pixels_utils as xut from pixels.error import PixelsError from pixels.constants import * +from pixels.configs import * +from pixels.stream import Stream +from pixels.units import SelectedUnits from common_utils.file_utils import load_yaml if TYPE_CHECKING: from typing import Optional, Literal -BEHAVIOUR_HZ = 25000 - -np.random.seed(BEHAVIOUR_HZ) - -# set si job_kwargs -job_kwargs = dict( - n_jobs=0.8, # 80% core - chunk_duration='1s', - progress_bar=True, -) - -si.set_global_job_kwargs(**job_kwargs) - -# instantiate WavPack compressor -wv_compressor = WavPack( - level=3, # high compression - bps=None, # lossless -) - -def _cacheable(method): - """ - Methods with this decorator will have their output cached to disk so that future - calls with the same set of arguments will simply load the result from disk. However, - if the key word argument list contains `units` and it is not either `None` or an - instance of `SelectedUnits` then this is disabled. 
- """ - def func(*args, **kwargs): - name = kwargs.pop("name", None) - - if "units" in kwargs: - units = kwargs["units"] - if not isinstance(units, SelectedUnits) or not hasattr(units, "name"): - return method(*args, **kwargs) - - self, *as_list = list(args) + list(kwargs.values()) - if not self._use_cache: - return method(*args, **kwargs) - - arrays = [i for i, arg in enumerate(as_list) if isinstance(arg, np.ndarray)] - if arrays: - if name is None: - raise PixelsError( - 'Cacheing methods when passing arrays requires also passing name="something"' - ) - for i in arrays: - as_list[i] = name - - as_list.insert(0, method.__name__) - - output = self.interim / 'cache' / ('_'.join(str(i) for i in as_list) + '.h5') - if output.exists() and self._use_cache != "overwrite": - # load cache - try: - df = ioutils.read_hdf5(output) - logging.info(f"> Cache loaded from {output}.") - except HDF5ExtError: - df = None - except (KeyError, ValueError): - # if key="df" is not found, then use HDFStore to list and read - # all dfs - # create df as a dictionary to hold all dfs - df = {} - with pd.HDFStore(output, "r") as store: - # list all keys - keys = store.keys() - # TODO apr 2 2025: for now the nested dict have keys in the - # format of `/imec0/positions`, this will not be the case - # once i flatten files at the stream level rather than - # session level, i.e., every pixels related cache will have - # stream id in their name. - for key in keys: - # remove "/" in key and split - key_name = key.lstrip("/").split("/") - if len(key_name) == 1: - # use the only key name as dict key - df[key_name[0]] = store[key] - elif len(key_name) == 2: - # stream id is the first - stream = key_name[0] - # data name is the second - name = "/".join(key_name[1:]) - if stream not in df: - df[stream] = {} - df[stream][name] = store[key] - logging.info(f"> Cache loaded from {output}.") - else: - df = method(*args, **kwargs) - output.parent.mkdir(parents=True, exist_ok=True) - if df is None: - output.touch() - else: - # allows to save multiple dfs in a dict in one hdf5 file - if ioutils.is_nested_dict(df): - for stream_id, nested_dict in df.items(): - # NOTE: we remove `.ap` in stream id cuz having `.`in - # the key name get problems - for name, values in nested_dict.items(): - key = f"/{stream_id}/{name}" - ioutils.write_hdf5( - path=output, - df=values, - key=key, - mode="a", - ) - elif isinstance(df, dict): - for name, values in df.items(): - ioutils.write_hdf5( - path=output, - df=values, - key=name, - mode="a", - ) - else: - ioutils.write_hdf5(output, df) - return df - return func - - -class SelectedUnits(list): - name: str - """ - This is the return type of Behaviour.select_units, which is a list in every way - except that when represented as a string, it can return a name, if a `name` - attribute has been set on it. This allows methods that have had `units` passed to be - cached to file. - """ - def __repr__(self): - if hasattr(self, "name"): - return self.name - return list.__repr__(self) - - class Behaviour(ABC): """ This class represents a single individual recording session. 
From a8bcfdfe40a9aad8a12b77d29c5081af5b64901b Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:28:02 +0100 Subject: [PATCH 312/658] put stream-level implementation in stream.py --- pixels/behaviours/base.py | 428 +++++++++----------------------------- 1 file changed, 94 insertions(+), 334 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a0fa378..afa96ff 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -786,12 +786,17 @@ def load_raw_ap(self): # if multiple runs for the same probe, concatenate them streams = self.files["pixels"] - for stream_id, stream_files in streams.items(): - # get paths of raw - paths = [self.find_file(path) for path in stream_files["ap_raw"]] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + raw_rec = stream.load_raw_ap() # now the value for streams dict is recording extractor - stream_files["si_rec"] = xut.load_raw(paths, stream_id) + stream_files["si_rec"] = raw_rec return None @@ -1675,240 +1680,25 @@ def _get_aligned_trials( trials * units * temporal bins (100ms) """ - action_labels = self.get_action_labels()[0] streams = self.files["pixels"] output = {} - if units is None: - units = self.select_units() - - #if not pos_bin is None: - behaviour_files = self.files["behaviour"] - # assume only one vr session for now - vr_dir = self.find_file(behaviour_files["vr_synched"][0]) - vr_data = ioutils.read_hdf5(vr_dir) - # get positions - positions = vr_data.position_in_tunnel - - #TODO: with multiple streams, spike times will be a list with multiple dfs, - #make sure old code does not break! - spikes = self.get_spike_times(units, use_si=True) - # drop rows if all nans - spikes = spikes.dropna(how="all") - - # since each session has one behaviour session, now only one action - # label file - actions = action_labels["outcome"] - events = action_labels["events"] - # get timestamps index of behaviour in self.SAMPLE_RATE hz, to convert - # it to ms, do timestamps*1000/self.SAMPLE_RATE - timestamps = action_labels["timestamps"] - - # select frames of wanted trial type - trials = np.where(np.bitwise_and(actions, label))[0] - # map starts by event - starts = np.where(np.bitwise_and(events, event))[0] - # map starts by end event - ends = np.where(np.bitwise_and(events, end_event))[0] - - # only take starts from selected trials - selected_starts = trials[np.where(np.isin(trials, starts))[0]] - start_t = timestamps[selected_starts] - # only take ends from selected trials - selected_ends = trials[np.where(np.isin(trials, ends))[0]] - end_t = timestamps[selected_ends] - if selected_starts.size == 0: - logging.info(f"> No trials found with label {label} and event {event}, " - "output will be empty.") - for key in streams.keys(): - assert 0 - output[key[:-3]] = {} - return output - - # use original trial id as trial index - trial_ids = vr_data.iloc[selected_starts].trial_count.unique() - - # pad ends with 1 second extra to remove edge effects from convolution - scan_pad = self.SAMPLE_RATE - scan_starts = start_t - scan_pad - scan_ends = end_t + scan_pad - scan_durations = scan_ends - scan_starts - - cursor = 0 # In sample points - rec_trials_fr = {} - rec_trials_spiked = {} - trial_positions = {} - for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - stream = stream_id[:-3] - # allows multiple streams of recording, i.e., multiple probes
rec_trials_fr[stream] = {} - rec_trials_spiked[stream] = {} - - # Account for multiple raw data files - meta = self.ap_meta[stream_num] - samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 - assert samples.is_integer() - in_SAMPLE_RATE_scale = (samples * self.SAMPLE_RATE)\ - / int(self.ap_meta[0]['imSampRate']) - cursor_duration = (cursor * self.SAMPLE_RATE)\ - / int(self.ap_meta[0]['imSampRate']) - rec_spikes = spikes[ - (cursor_duration <= spikes)\ - & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) - ] - cursor_duration - cursor += samples - - # Account for lag, in case the ephys recording was started before the - # behaviour - if not self._lag[stream_num] == None: - lag_start, _ = self._lag[stream_num] - else: - lag_start = timestamps[0] - - if lag_start < 0: - rec_spikes = rec_spikes + lag_start - - for i, start in enumerate(selected_starts): - # select spike times of current trial - trial_bool = (rec_spikes >= scan_starts[i])\ - & (rec_spikes <= scan_ends[i]) - trial = rec_spikes[trial_bool] - - # get position bin ids for current trial - trial_pos_bool = (positions.index >= start_t[i])\ - & (positions.index < end_t[i]) - trial_pos = positions[trial_pos_bool] - - # initiate binary spike times array for current trial - # NOTE: dtype must be float otherwise would get all 0 when - # passing gaussian kernel - times = np.zeros((scan_durations[i], len(units))).astype(float) - # use pixels time as spike index - idx = np.arange(scan_starts[i], scan_ends[i]) - # make it df, column name being unit id - spiked = pd.DataFrame(times, index=idx, columns=units) - - # TODO mar 5 2025: - # how to separate aligned trial times and chance, so that i can - # use cacheable to get all conditions?????? - - for unit in trial: - # get spike time for unit - u_times = trial[unit].values - # drop nan - u_times = u_times[~np.isnan(u_times)] - - # round spike times to use it as index - u_spike_idx = np.round(u_times).astype(int) - # make sure it does not exceed scan duration - if (u_spike_idx >= scan_ends[i]).any(): - beyonds = np.where(u_spike_idx >= scan_ends[i])[0] - u_spike_idx[beyonds] = idx[-1] - # make sure no double counted - u_spike_idx = np.unique(u_spike_idx) - - # set spiked to 1 - spiked.loc[u_spike_idx, unit] = 1 - - # convolve spike trains into spike rates - rates = signal.convolve_spike_trains( - times=spiked, - sigma=sigma, - sample_rate=self.SAMPLE_RATE, - ) - # remove 1s padding from the start and end - rates = rates.iloc[scan_pad: -scan_pad] - spiked = spiked.iloc[scan_pad: -scan_pad] - # reset index to zero at the beginning of the trial - rates.reset_index(inplace=True, drop=True) - rec_trials_fr[stream][trial_ids[i]] = rates - spiked.reset_index(inplace=True, drop=True) - rec_trials_spiked[stream][trial_ids[i]] = spiked - trial_pos.reset_index(inplace=True, drop=True) - trial_positions[trial_ids[i]] = trial_pos - - #if not rec_trials_fr[stream]: - # return None - - # concat trial positions - positions = ioutils.reindex_by_longest( - dfs=trial_positions, - idx_names=["trial", "time"], - level="trial", - return_format="dataframe", + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + output[stream_id] = stream.get_aligned_trials( + units=units, # NOTE: ALWAYS the first arg + data="trial_rate", # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, ) - if data == "trial_times": - # get trials vertically stacked spiked - stacked_spiked = pd.concat( - rec_trials_spiked[stream], - 
axis=0, - ) - stacked_spiked.index.names = ["trial", "time"] - stacked_spiked.columns.names = ["unit"] - - # save index and columns to reconstruct df for shuffled data - ioutils.save_index_to_frame( - df=stacked_spiked, - path=self.interim / stream_files["shuffled_index"], - ) - ioutils.save_cols_to_frame( - df=stacked_spiked, - path=self.interim / stream_files["shuffled_columns"], - ) - - # get chance data paths - paths = { - "spiked_memmap_path": self.interim /\ - stream_files["spiked_shuffled_memmap"], - "fr_memmap_path": self.interim /\ - stream_files["fr_shuffled_memmap"], - "spiked_df_path": self.processed / stream_files["spiked_shuffled"], - "fr_df_path": self.processed / stream_files["fr_shuffled"], - } - - # save chance data - xut.save_spike_chance( - **paths, - sigma=sigma, - sample_rate=self.SAMPLE_RATE, - spiked=stacked_spiked, - ) - - # get trials horizontally stacked spiked - spiked = ioutils.reindex_by_longest( - dfs=stacked_spiked, - level="trial", - return_format="dataframe", - ) - - output[stream] = { - "spiked": spiked, - "positions": positions, - } - - elif data == "trial_rate": - # TODO apr 2 2025: make sure this reindex_by_longest works - fr = ioutils.reindex_by_longest( - dfs=rec_trials_fr[stream], - level="trial", - idx_names=["trial", "time"], - col_names=["unit"], - return_format="dataframe", - ) - - output[stream] = { - "fr": fr, - "positions": positions, - } - - # concat output into dataframe before cache - #df = pd.concat( - # objs=output, - # axis=1, - # names=["stream"], - #) return output @@ -2966,112 +2756,27 @@ def bin_aligned_trials( pos_bin: int | None For VR behaviour, size of positional bin for position data. - """ - # TODO mar 31 2025: - # use cached get_aligned_trials to bin so that we do not need to - # duplicate code - - bin_frs = {} - bin_counts = {} - bin_counts_chance = {} - - # get aligned spiked and positions - spiked = self.align_trials( - units=units, # NOTE: ALWAYS the first arg - data="trial_times", # NOTE: ALWAYS the second arg - label=label, - event=event, - sigma=sigma, - end_event=end_event, - ) - fr = self.align_trials( - units=units, # NOTE: ALWAYS the first arg - data="trial_rate", # NOTE: ALWAYS the second arg - label=label, - event=event, - sigma=sigma, - end_event=end_event, - ) - + binned = {} streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - stream = stream_id[:-3] - # define output path for binned spike rate - output_fr_path = self.interim/\ - f'cache/{self.name}_{stream}_{label}_{units}_{time_bin}_spike_rate.npz' - output_count_path = self.interim/\ - f'cache/{self.name}_{stream}_{label}_{units}_{time_bin}_spike_count.npz' - if output_count_path.exists(): - continue - - logging.info(f"\n> Binning trials from {stream_id}.") - - # get stream spiked - stream_spiked = spiked[stream]["spiked"] - if stream_spiked.size == 0: - logging.info(f"\n> No units found in {units}, continue.") - return None - - # get stream positions - positions = spiked[stream]["positions"] - # get stream firing rates - stream_fr = fr[stream]["fr"] - - # TODO apr 11 2025: - # bin chance while bin data - #spiked_chance_path = self.processed / stream_files["spiked_shuffled"] - #spiked_chance = ioutils.read_hdf5(spiked_chance_path, "spiked") - #bin_counts_chance[stream_id] = {} - - bin_frs[stream_id] = {} - bin_counts[stream_id] = {} - for trial in positions.columns.unique(): - counts = stream_spiked.xs(trial, level="trial", axis=1).dropna() - rates = stream_fr.xs(trial, level="trial", axis=1).dropna() - 
trial_pos = positions[trial].dropna() - - # get bin spike count - bin_counts[stream_id][trial] = xut.bin_vr_trial( - data=counts, - positions=trial_pos, - sample_rate=self.SAMPLE_RATE, - time_bin=time_bin, - pos_bin=pos_bin, - bin_method="sum", - ) - # get bin firing rates - bin_frs[stream_id][trial] = xut.bin_vr_trial( - data=rates, - positions=trial_pos, - sample_rate=self.SAMPLE_RATE, - time_bin=time_bin, - pos_bin=pos_bin, - bin_method="mean", - ) - - # stack df values into np array - # reshape into trials x units x bins - bin_count_arr = ioutils.reindex_by_longest(bin_counts[stream_id]).T - bin_fr_arr = ioutils.reindex_by_longest(bin_frs[stream_id]).T - - # save bin_fr and bin_count, for alfredo & andrew - # use label as array key name - fr_to_save = { - "fr": bin_fr_arr[:, :-2, :], - "pos": bin_fr_arr[:, -2:, :], - } - np.savez_compressed(output_fr_path, **fr_to_save) - logging.info(f"> Output saved at {output_fr_path}.") - - count_to_save = { - "count": bin_count_arr[:, :-2, :], - "pos": bin_count_arr[:, -2:, :], - } - np.savez_compressed(output_count_path, **count_to_save) - logging.info(f"> Output saved at {output_count_path}.") + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + binned[stream_id] = stream.get_binned_trials( + label=label, + event=event, + units=units, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + ) - return None + return binned def get_chance_data(self, time_bin, pos_bin): @@ -3129,3 +2834,58 @@ def get_chance_data(self, time_bin, pos_bin): #spiked_chance = ioutils.read_hdf5(spiked_chance_path, key="spiked") return None + + + def save_spike_chance(self, spiked, sigma, sample_rate): + # TODO apr 21 2025: + # do we put this func here or in stream.py??? + + # save index and columns to reconstruct df for shuffled data + ioutils.save_index_to_frame( + df=spiked, + path=self.interim / stream_files["shuffled_index"], + ) + ioutils.save_cols_to_frame( + df=spiked, + path=self.interim / stream_files["shuffled_columns"], + ) + + # get chance data paths + paths = { + "spiked_memmap_path": self.interim /\ + stream_files["spiked_shuffled_memmap"], + "fr_memmap_path": self.interim /\ + stream_files["fr_shuffled_memmap"], + } + + # save chance data + xut.save_spike_chance( + **paths, + sigma=sigma, + sample_rate=self.SAMPLE_RATE, + spiked=spiked, + ) + + return None + + + def sync_vr(self, vr): + """ + Synchronise each pixels stream with virtual reality data. + + params + === + vr: class, virtual reality object. + """ + streams = self.files["pixels"] + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + stream.sync_vr(vr) + + return None From 22d7948ac94d7ebc5f61c3dd53243ba8b5ef26d4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:28:28 +0100 Subject: [PATCH 313/658] use kilosort 4.0.30 --- pixels/behaviours/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index afa96ff..e125fa0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -812,7 +812,8 @@ def sort_spikes(self, mc_method="dredge"): (as of jan 2025, dredge performs better than ks motion correction.) "ks": do motion correction with kilosort. 
""" - ks_image_path = self.interim.parent/"ks4_with_wavpack.sif" + ks_image_path = self.interim.parent/"ks4-0-30_with_wavpack.sif" + #ks_image_path = self.interim.parent/"ks4-0-18_with_wavpack.sif" if not ks_image_path.exists(): raise PixelsError("Have you craeted Singularity image for sorting?") From 1ddd38dd6d246cafe497ce2bd669809f0a2ec07b Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:28:46 +0100 Subject: [PATCH 314/658] preprocess raw first if use ks motion correction --- pixels/behaviours/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e125fa0..2e0bf32 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -856,6 +856,8 @@ def sort_spikes(self, mc_method="dredge"): # load rec if ks_mc: + # preprocess raw recording + self.preprocess_raw() rec = stream_files["preprocessed"] else: rec_dir = self.find_file(stream_files["motion_corrected"]) From 3d818a3dc4bb0c4d538ec09333c867a00b15c556 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:29:16 +0100 Subject: [PATCH 315/658] raise error if noisy unit ids not saved --- pixels/behaviours/base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2e0bf32..7da6e47 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1753,10 +1753,14 @@ def select_units( sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser temp_sa = si.load_sorting_analyzer(sa_dir) + # remove noisy units - noisy_units = load_yaml( - path=self.find_file(stream_files["noisy_units"]), - ) + try: + noisy_units = load_yaml( + path=self.find_file(stream_files["noisy_units"]), + ) + except: + raise PixelsError("> Have you labelled noisy units?") # remove units from sorting and reattach to sa to keep properties sorting = temp_sa.sorting.remove_units(remove_unit_ids=noisy_units) sa = temp_sa.remove_units(remove_unit_ids=noisy_units) From ff8ee03f13401e2c7fa3cfdea502d44eea336628 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:29:52 +0100 Subject: [PATCH 316/658] add doc --- pixels/behaviours/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 7da6e47..cd85309 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1961,9 +1961,8 @@ def align_trials( 'motion_index', # Motion index per ROI from the video 'motion_tracking', # Motion tracking coordinates from DLC 'trial_rate', # Taking spike times from the whole duration of each - # trial, convolve into spike rate - 'trial_times', # Taking spike times from the whole duration of each - # trial, get spike boolean + # trial, convolve into spike rate, output also + # contains times. 
] if data not in data_options: raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") From 4d316efcfedb37782b02cfb2fb574a92b89a2df8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:30:27 +0100 Subject: [PATCH 317/658] remove todo --- pixels/behaviours/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index cd85309..77ceb79 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2646,10 +2646,6 @@ def get_positional_rate( # get constants from vd from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END - # TODO dec 18 2024: - # rearrange vr specific funcs to vr module - # put pixels specific funcs in pixels module - # NOTE: order of args matters for loading the cache! # always put units first, cuz it is like that in # experiemnt.align_trials, otherwise the same cache cannot be loaded From 9e45b28d72851d8ac6273edb092bc72bf3f8d7d9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:30:58 +0100 Subject: [PATCH 318/658] move configs to configs.py --- pixels/pixels_utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 299f052..9a7d596 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -24,14 +24,6 @@ from common_utils.math_utils import random_sampling from common_utils.file_utils import init_memmap, read_hdf5 -# set si job_kwargs -job_kwargs = dict( - n_jobs=0.8, # 80% core - chunk_duration='1s', - progress_bar=True, -) -si.set_global_job_kwargs(**job_kwargs) - def load_raw(paths, stream_id): """ Load raw recording file from spikeglx. From dfef9a68d64973c5a083ca9ff84f2275fc271d47 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:31:16 +0100 Subject: [PATCH 319/658] save group_ids as var --- pixels/pixels_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9a7d596..9936d69 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -50,7 +50,9 @@ def load_raw(paths, stream_id): def preprocess_raw(rec, surface_depths): - if np.unique(rec.get_channel_groups()).size < 4: + group_ids = rec.get_channel_groups() + + if np.unique(group_ids).size < 4: # correct group id if not all shanks used group_ids = correct_group_id(rec) # change the group id From 71eb9d0ac03c5f977db481179e8063498b62143d Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:32:17 +0100 Subject: [PATCH 320/658] add imports --- pixels/stream.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index df570a8..713fd05 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1,9 +1,13 @@ import numpy as np +import pandas as pd + +import spikeinterface as si from pixels import ioutils from pixels import pixels_utils as xut import pixels.signal_utils as signal from pixels.configs import * +from pixels.decorators import cacheable from common_utils import file_utils From ebc71e7626766dcf917d718b2e98ac2f783940af Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:33:06 +0100 Subject: [PATCH 321/658] add attr; implement stream level operations --- pixels/stream.py | 371 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 368 insertions(+), 3 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 713fd05..283491c 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -23,20 +23,263 @@ def __init__( self.stream_num = 
stream_num self.probe_id = stream_id[:-3] self.files = files + self.session = session self.behaviour_files = session.files["behaviour"] - self.SAMPLE_RATE = session.SAMPLE_RATE + self.BEHAVIOUR_SAMPLE_RATE = session.SAMPLE_RATE + self.raw = session.raw + self.interim = session.interim + self.processed = session.processed + self.cache = self.interim / "cache/" + + self._use_cache = True def __repr__(self): return f"" + def load_raw_ap(self): + paths = [self.session.find_file(path) for path in self.files["ap_raw"]] + self.files["si_rec"] = xut.load_raw(paths, self.stream_id) + + return self.files["si_rec"] + + + @cacheable + def align_trials(self, units, data, label, event, sigma, end_event): + + if "trial" in data: + logging.info( + f"Aligning {data} of {units} units to <{label}> trials." + ) + return self._get_aligned_trials( + label, event, data=data, units=units, sigma=sigma, + end_event=end_event, + ) + else: + raise NotImplementedError( + "> Other types of alignment are not implemented." + ) + + def _get_aligned_trials( + self, label, event, data, units=None, sigma=None, end_event=None, + ): + # get synched pixels stream with vr and action labels + synched_vr, action_labels = self.get_synched_vr() + + # get positions of all trials + all_pos = synched_vr.position_in_tunnel + + # get spike times + spikes = self.get_spike_times(units) + + # get action and event label file + actions = action_labels["outcome"] + events = action_labels["events"] + # get timestamps index of behaviour in self.BEHAVIOUR_SAMPLE_RATE hz, to convert + # it to ms, do timestamps*1000/self.BEHAVIOUR_SAMPLE_RATE + timestamps = action_labels["timestamps"] + + # select frames of wanted trial type + trials = np.where(np.bitwise_and(actions, label))[0] + # map starts by event + starts = np.where(np.bitwise_and(events, event))[0] + # map starts by end event + ends = np.where(np.bitwise_and(events, end_event))[0] + + # only take starts from selected trials + selected_starts = trials[np.where(np.isin(trials, starts))[0]] + start_t = timestamps[selected_starts] + # only take ends from selected trials + selected_ends = trials[np.where(np.isin(trials, ends))[0]] + end_t = timestamps[selected_ends] + + if selected_starts.size == 0: + logging.info(f"> No trials found with label {label} and event " + f"{event}, output will be empty.") + return None + + # use original trial id as trial index + trial_ids = synched_vr.iloc[selected_starts].trial_count.unique() + + # pad ends with 1 second extra to remove edge effects from + # convolution + scan_pad = self.BEHAVIOUR_SAMPLE_RATE + scan_starts = start_t - scan_pad + scan_ends = end_t + scan_pad + scan_durations = scan_ends - scan_starts + + cursor = 0 + raw_rec = self.load_raw_ap() + samples = raw_rec.get_total_samples() + # Account for multiple raw data files + in_SAMPLE_RATE_scale = (samples * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + cursor_duration = (cursor * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + rec_spikes = spikes[ + (cursor_duration <= spikes)\ + & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) + ] - cursor_duration + cursor += samples + + output = {} + trials_fr = {} + trials_spiked = {} + trials_positions = {} + for i, start in enumerate(selected_starts): + # select spike times of current trial + trial_bool = (rec_spikes >= scan_starts[i])\ + & (rec_spikes <= scan_ends[i]) + trial = rec_spikes[trial_bool] + # get position bin ids for current trial + trial_pos_bool = (all_pos.index >= start_t[i])\ + & (all_pos.index < end_t[i]) + 
trial_pos = all_pos[trial_pos_bool] + + # initiate binary spike times array for current trial + # NOTE: dtype must be float otherwise would get all 0 when passing + # gaussian kernel + times = np.zeros((scan_durations[i], len(units))).astype(float) + # use pixels time as spike index + idx = np.arange(scan_starts[i], scan_ends[i]) + # make it df, column name being unit id + spiked = pd.DataFrame(times, index=idx, columns=units) + + # TODO mar 5 2025: how to separate aligned trial times and chance, + # so that i can use cacheable to get all conditions?????? + for unit in trial: + # get spike time for unit + u_times = trial[unit].values + # drop nan + u_times = u_times[~np.isnan(u_times)] + # round spike times to use it as index + u_spike_idx = np.round(u_times).astype(int) + # make sure it does not exceed scan duration + if (u_spike_idx >= scan_ends[i]).any(): + beyonds = np.where(u_spike_idx >= scan_ends[i])[0] + u_spike_idx[beyonds] = idx[-1] + # make sure no double counted + u_spike_idx = np.unique(u_spike_idx) + + # set spiked to 1 + spiked.loc[u_spike_idx, unit] = 1 + + # convolve spike trains into spike rates + rates = signal.convolve_spike_trains( + times=spiked, + sigma=sigma, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + ) + + # remove 1s padding from the start and end + rates = rates.iloc[scan_pad: -scan_pad] + spiked = spiked.iloc[scan_pad: -scan_pad] + + # reset index to zero at the beginning of the trial + rates.reset_index(inplace=True, drop=True) + trials_fr[trial_ids[i]] = rates + spiked.reset_index(inplace=True, drop=True) + trials_spiked[trial_ids[i]] = spiked + trial_pos.reset_index(inplace=True, drop=True) + trials_positions[trial_ids[i]] = trial_pos + + # concat trial positions + positions = ioutils.reindex_by_longest( + dfs=trials_positions, + idx_names=["trial", "time"], + level="trial", + return_format="dataframe", + ) + + # get trials vertically stacked spiked + stacked_spiked = pd.concat( + trials_spiked, + axis=0, + ) + stacked_spiked.index.names = ["trial", "time"] + stacked_spiked.columns.names = ["unit"] + + # TODO apr 21 2025: + # save spike chance only if all units are selected, else + # only index into the big chance array and save into zarr + #if units.name == "all" and (label == 725 or 1322): + # self.save_spike_chance( + # stream_files=stream_files, + # spiked=stacked_spiked, + # sigma=sigma, + # ) + #else: + # # access chance data if we only need part of the units + # self.get_spike_chance( + # sample_rate=self.SAMPLE_RATE, + # positions=all_pos, + # sigma=sigma, + # ) + # assert 0 + + # get trials horizontally stacked spiked + spiked = ioutils.reindex_by_longest( + dfs=stacked_spiked, + level="trial", + return_format="dataframe", + ) + fr = ioutils.reindex_by_longest( + dfs=trials_fr, + level="trial", + idx_names=["trial", "time"], + col_names=["unit"], + return_format="dataframe", + ) + + output["spiked"] = spiked + output["fr"] = fr + output["positions"] = positions + + return output + + + def get_spike_times(self, units): + # find sorting analyser path + sa_path = self.session.find_file(self.files["sorting_analyser"]) + # load sorting analyser + temp_sa = si.load_sorting_analyzer(sa_path) + # select units + sorting = temp_sa.sorting.select_units(units) + sa = temp_sa.select_units(units) + sa.sorting = sorting + + times = {} + # get spike train + for i, unit_id in enumerate(sa.unit_ids): + unit_times = sa.sorting.get_unit_spike_train( + unit_id=unit_id, + return_times=False, + ) + times[unit_id] = pd.Series(unit_times) + + # concatenate units + 
spike_times = pd.concat( + objs=times, + axis=1, + names="unit", + ) + # get sampling frequency + fs = int(sa.sampling_frequency) + # Convert to time into sample rate index + spike_times /= fs / self.BEHAVIOUR_SAMPLE_RATE + + return spike_times + def sync_vr(self, vr): # get action labels action_labels = self.session.get_action_labels()[self.stream_num] - if action_labels: + synched_vr_path = self.session.find_file( + self.behaviour_files['vr_synched'][self.stream_num], + ) + if action_labels and synched_vr_path: logging.info(f"\n> {self.stream_id} from {self.session.name} is " - "already synched with vr, continue.") + "already synched with vr, now loaded.") else: _sync_vr(self, vr) @@ -129,6 +372,128 @@ def _sync_vr(self, vr): return None + def get_synched_vr(self): + """ + Get synchronised vr data and action labels. + """ + action_labels = self.session.get_action_labels()[self.stream_num] + + synched_vr_path = self.session.find_file( + self.behaviour_files["vr_synched"][self.stream_num], + ) + synched_vr = file_utils.read_hdf5(synched_vr_path) + + return synched_vr, action_labels + + + def get_binned_trials( + self, label, event, units=None, sigma=None, end_event=None, + time_bin=None, pos_bin=None + ): + # define output path for binned spike rate + output_path = self.cache/ f"{self.session.name}_{label}_{units}_"\ + f"{time_bin}_{pos_bin}cm_{self.stream_id}.npz" + + if output_path.exists(): + binned = np.load(output_path) + logging.info( + f"\n> <{label}> trials from {self.stream_id} in {units} " + "already binned, now loaded." + ) + else: + binned = self._bin_aligned_trials( + label=label, + event=event, + units=units, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + output_path=output_path, + ) + + return binned + + + def _bin_aligned_trials( + self, label, event, units, sigma, end_event, time_bin, pos_bin, + output_path, + ): + # get aligned trials + trials = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="trial_rate", # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + + logging.info( + f"\n> Binning label <{label}> trials from {self.stream_id} " + f"in {units}." 
+ ) + + # get fr, spiked, positions + fr = trials["fr"] + spiked = trials["spiked"] + positions = trials["positions"] + + if spiked.size == 0: + logging.info(f"\n> No units found in {units}, continue.") + assert 0 + return None + + # TODO apr 11 2025: + # bin chance while bin data + #spiked_chance_path = self.processed / stream_files["spiked_shuffled"] + #spiked_chance = ioutils.read_hdf5(spiked_chance_path, "spiked") + #bin_counts_chance[stream_id] = {} + + binned = {} + binned_count = {} + binned_fr = {} + for trial in positions.columns.unique(): + counts = spiked.xs(trial, level="trial", axis=1).dropna() + rates = fr.xs(trial, level="trial", axis=1).dropna() + trial_pos = positions[trial].dropna() + + # get bin spike count + binned_count[trial] = xut.bin_vr_trial( + data=counts, + positions=trial_pos, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="sum", + ) + # get bin firing rates + binned_fr[trial] = xut.bin_vr_trial( + data=rates, + positions=trial_pos, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="mean", + ) + + # stack df values into np array + # reshape into trials x units x bins + binned_count = ioutils.reindex_by_longest(binned_count).T + binned_fr = ioutils.reindex_by_longest(binned_fr).T + + # save bin_fr and bin_count, for andrew + # use label as array key name + binned["count"] = binned_count[:, :-2, :] + binned["fr"] = binned_fr[:, :-2, :] + binned["pos"] = binned_count[:, -2:, :] + + np.savez_compressed(output_path, **binned) + logging.info(f"\n> Output saved at {output_path}.") + + return binned + + def save_spike_chance(self, spiked, sigma): # TODO apr 21 2025: # do we put this func here or in stream.py?? From 8789d778d9c7e6565e1eb871c9a6dd7197fca477 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:33:48 +0100 Subject: [PATCH 322/658] make to use the right name --- pixels/stream.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 283491c..2d7714e 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -313,7 +313,7 @@ def _sync_vr(self, vr): downsampled = signal.decimate( array=syncs, from_hz=spike_samp_rate, - to_hz=self.SAMPLE_RATE, + to_hz=self.BEHAVIOUR_SAMPLE_RATE, ) # binarise to avoid non integers pixels_syncs = signal.binarise(downsampled) @@ -340,20 +340,20 @@ def _sync_vr(self, vr): pixels_idx = np.arange(pixels_syncs.shape[0]) synched_vr = vr.sync_streams( - self.SAMPLE_RATE, + self.BEHAVIOUR_SAMPLE_RATE, pixels_vr_edges, pixels_idx, )[vr_session.name] # save to pixels processed dir file_utils.write_hdf5( - self.session.processed /\ + self.processed /\ self.behaviour_files['vr_synched'][self.stream_num], synched_vr, ) # get action label dir - action_labels_path = self.session.processed /\ + action_labels_path = self.processed /\ self.behaviour_files["action_labels"][self.stream_num] # extract and save action labels @@ -519,7 +519,7 @@ def save_spike_chance(self, spiked, sigma): xut.save_spike_chance( **paths, sigma=sigma, - sample_rate=self.SAMPLE_RATE, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, spiked=spiked, ) From 0e44d68fda8a018c4074fe3017f59916cf7635c3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:34:01 +0100 Subject: [PATCH 323/658] put wrapper cacheable in decorators.py --- pixels/decorators.py | 100 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 pixels/decorators.py diff --git 
a/pixels/decorators.py b/pixels/decorators.py new file mode 100644 index 0000000..0e50e64 --- /dev/null +++ b/pixels/decorators.py @@ -0,0 +1,100 @@ +import numpy as np +import pandas as pd +from tables import HDF5ExtError + +from pixels.configs import * +from pixels import ioutils +from pixels.error import PixelsError +from pixels.units import SelectedUnits + +def cacheable(method): + """ + Methods with this decorator will have their output cached to disk so that + future calls with the same set of arguments will simply load the result from + disk. However, from pixels.error import PixelsError if the key word argument + list contains `units` and it is not either `None` or an instance of + `SelectedUnits` then this is disabled. + """ + def wrapper(*args, **kwargs): + name = kwargs.pop("name", None) + + if "units" in kwargs: + units = kwargs["units"] + if not isinstance(units, SelectedUnits) or not hasattr(units, "name"): + return method(*args, **kwargs) + + self, *as_list = list(args) + list(kwargs.values()) + if not self._use_cache: + return method(*args, **kwargs) + + arrays = [i for i, arg in enumerate(as_list) if isinstance(arg, np.ndarray)] + if arrays: + if name is None: + raise PixelsError( + "Cacheing methods when passing arrays requires also " + "passing name='something'" + ) + for i in arrays: + as_list[i] = name + + # build a key: method name + all args + key_parts = [method.__name__] + [str(i) for i in as_list] + cache_path = self.cache /\ + ("_".join(key_parts) + f"_{self.stream_id}.h5") + + if cache_path.exists() and self._use_cache != "overwrite": + # load cache + try: + df = ioutils.read_hdf5(cache_path) + logging.info(f"> Cache loaded from {cache_path}.") + except HDF5ExtError: + df = None + except (KeyError, ValueError): + # if key="df" is not found, then use HDFStore to list and read + # all dfs + # create df as a dictionary to hold all dfs + df = {} + with pd.HDFStore(cache_path, "r") as store: + # list all keys + for key in store.keys(): + # remove "/" in key and split + parts = key.lstrip("/").split("/") + if len(parts) == 1: + # use the only key name as dict key + df[parts[0]] = store[key] + elif len(parts) == 2: + # stream id is the first, data name is the second + stream, name = parts[0], "/".join(parts[1:]) + df.setdefault(stream, {})[name] = store[key] + logging.info(f"> Cache loaded from {cache_path}.") + else: + df = method(*args, **kwargs) + cache_path.parent.mkdir(parents=True, exist_ok=True) + if df is None: + cache_path.touch() + logging.info("\n> df is None, cache will exist but be empty.") + else: + # allows to save multiple dfs in a dict in one hdf5 file + if ioutils.is_nested_dict(df): + for probe_id, nested_dict in df.items(): + # NOTE: we remove `.ap` in stream id cuz having `.`in + # the key name get problems + for name, values in nested_dict.items(): + ioutils.write_hdf5( + path=cache_path, + df=values, + key=f"/{probe_id}/{name}", + mode="a", + ) + elif isinstance(df, dict): + for name, values in df.items(): + ioutils.write_hdf5( + path=cache_path, + df=values, + key=name, + mode="a", + ) + else: + ioutils.write_hdf5(cache_path, df) + return df + return wrapper From 41cfbe5c4b6b0a08e5d36b961b9b723ac27de0c9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 30 Apr 2025 18:34:22 +0100 Subject: [PATCH 324/658] put SelectedUnits class separately --- pixels/units.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 pixels/units.py diff --git a/pixels/units.py b/pixels/units.py new file mode 100644 index 0000000..366aee9 --- 
/dev/null +++ b/pixels/units.py @@ -0,0 +1,12 @@ +class SelectedUnits(list): + name: str + """ + This is the return type of Behaviour.select_units, which is a list in every way + except that when represented as a string, it can return a name, if a `name` + attribute has been set on it. This allows methods that have had `units` passed to be + cached to file. + """ + def __repr__(self): + if hasattr(self, "name"): + return self.name + return list.__repr__(self) From ff1f719a12d3ff10d9f757cf97bd648ca24efc24 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 May 2025 15:29:33 +0100 Subject: [PATCH 325/658] formatting logging --- pixels/behaviours/base.py | 93 +++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 77ceb79..b8e3272 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -231,7 +231,7 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: raw = self.raw / name if raw.exists(): if copy: - logging.info(f" {self.name}: Copying {name} to interim") + logging.info(f"\n {self.name}: Copying {name} to interim") copyfile(raw, interim) return interim return raw @@ -239,11 +239,11 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: tar = raw.with_name(raw.name + '.tar.gz') if tar.exists(): if copy: - logging.info(f" {self.name}: Extracting {tar.name} to interim") + logging.info(f"\n {self.name}: Extracting {tar.name} to interim") with tarfile.open(tar) as open_tar: open_tar.extractall(path=self.interim) return interim - logging.info(f" {self.name}: Extracting {tar.name}") + logging.info(f"\n {self.name}: Extracting {tar.name}") with tarfile.open(tar) as open_tar: open_tar.extractall(path=self.raw) return raw @@ -274,11 +274,11 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): # TODO jan 14 2025: # this func is not used in vr behaviour, since they are synched # in vd.session - logging.info(" Finding lag between sync channels") + logging.info("\n Finding lag between sync channels") recording = self.files[rec_num] if behavioural_data is None: - logging.info(" Loading behavioural data") + logging.info("\n Loading behavioural data") data_file = self.find_file(recording['behaviour']) behavioural_data = ioutils.read_tdms(data_file, groups=["NpxlSync_Signal"]) behavioural_data = signal.resample( @@ -286,7 +286,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): ) if sync_channel is None: - logging.info(" Loading neuropixels sync channel") + logging.info("\n Loading neuropixels sync channel") data_file = self.find_file(recording['lfp_data']) num_chans = self.lfp_meta[rec_num]['nSavedChans'] sync_channel = ioutils.read_bin(data_file, num_chans, channel=384) @@ -297,7 +297,7 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): behavioural_data = signal.binarise(behavioural_data) sync_channel = signal.binarise(sync_channel) - logging.info(" Finding lag") + logging.info("\n Finding lag") plot = self.processed / f'sync_{rec_num}.png' lag_start, match = signal.find_sync_lag( behavioural_data, sync_channel, plot=plot, @@ -307,9 +307,9 @@ def sync_data(self, rec_num, behavioural_data=None, sync_channel=None): self._lag[rec_num] = (lag_start, lag_end) if match < 95: - logging.warning(" The sync channels did not match very well. " + logging.warning("\n The sync channels did not match very well. 
" "Check the plot.") - logging.info(f" Calculated lag: {(lag_start, lag_end)}") + logging.info(f"\n Calculated lag: {(lag_start, lag_end)}") lag_json = [] for lag in self._lag: @@ -467,10 +467,11 @@ def process_behaviour(self): # this func is not used by vr behaviour for rec_num, recording in enumerate(self.files): logging.info( - f">>>>> Processing behaviour for recording {rec_num + 1} of {len(self.files)}" + f"\n>>>>> Processing behaviour for recording {rec_num + 1}" + f" of {len(self.files)}" ) - logging.info(f"> Loading behavioural data") + logging.info(f"\n> Loading behavioural data") behavioural_data = ioutils.read_tdms(self.find_file(recording['behaviour'])) # ignore any columns that have Nans; these just contain settings @@ -478,12 +479,12 @@ def process_behaviour(self): if behavioural_data[col].isnull().values.any(): behavioural_data.drop(col, axis=1, inplace=True) - logging.info(f"> Downsampling to {self.SAMPLE_RATE} Hz") + logging.info(f"\n> Downsampling to {self.SAMPLE_RATE} Hz") behav_array = signal.resample(behavioural_data.values, 25000, self.SAMPLE_RATE) behavioural_data.iloc[:len(behav_array), :] = behav_array behavioural_data = behavioural_data[:len(behav_array)] - logging.info(f"> Syncing to Neuropixels data") + logging.info(f"\n> Syncing to Neuropixels data") if self._lag[rec_num] is None: self.sync_data( rec_num, @@ -493,20 +494,20 @@ def process_behaviour(self): behavioural_data = behavioural_data[max(lag_start, 0):-1-max(lag_end, 0)] behavioural_data.index = range(len(behavioural_data)) - logging.info(f"> Extracting action labels") + logging.info(f"\n> Extracting action labels") self._action_labels[rec_num] = self._extract_action_labels(rec_num, behavioural_data) output = self.processed / recording['action_labels'] np.save(output, self._action_labels[rec_num]) - logging.info(f"> Saved to: {output}") + logging.info(f"\n> Saved to: {output}") output = self.processed / recording['behaviour_processed'] - logging.info(f"> Saving downsampled behavioural data to:") - logging.info(f" {output}") + logging.info(f"\n> Saving downsampled behavioural data to:") + logging.info(f"\n {output}") behavioural_data.drop("/'NpxlSync_Signal'/'0'", axis=1, inplace=True) ioutils.write_hdf5(output, behavioural_data) self._behavioural_data[rec_num] = behavioural_data - logging.info("> Done!") + logging.info("\n> Done!") def correct_motion(self, mc_method="dredge"): """ @@ -524,7 +525,7 @@ def correct_motion(self, mc_method="dredge"): None """ if mc_method == "ks": - logging.info(f"> Correct motion later with {mc_method}.") + logging.info(f"\n> Correct motion later with {mc_method}.") return None # get pixels streams @@ -533,7 +534,7 @@ def correct_motion(self, mc_method="dredge"): for stream_id, stream_files in streams.items(): output = self.interim / stream_files["motion_corrected"] if output.exists(): - logging.info(f"> Motion corrected {stream_id} loaded.") + logging.info(f"\n> Motion corrected {stream_id} loaded.") continue # preprocess raw recording @@ -628,7 +629,7 @@ def detect_n_localise_peaks(self, loc_method="monopolar_triangulation"): for stream_id, stream_files in streams.items(): output = self.processed / stream_files["detected_peaks"] if output.exists(): - logging.info(f"> Peaks from {stream_id} already detected.") + logging.info(f"\n> Peaks from {stream_id} already detected.") continue # get ap band @@ -661,14 +662,14 @@ def extract_bands(self, freqs=None): for name, freqs in bands.items(): output = self.processed / stream_files[f"{name}_extracted"] if output.exists(): - 
logging.info(f"> {name} bands from {stream_id} loaded.") + logging.info(f"\n> {name} bands from {stream_id} loaded.") continue # preprocess raw data self.preprocess_raw() logging.info( - f">>>>> Extracting {name} bands from {self.name} " + f"\n>>>>> Extracting {name} bands from {self.name} " f"{stream_id} in total of {self.stream_count} stream(s)" ) @@ -694,7 +695,7 @@ def extract_bands(self, freqs=None): self.sync_data(rec_num, sync_channel=data[:, -1]) lag_start, lag_end = self._lag[rec_num] - logging.info(f"> Saving data to {output}") + logging.info(f"\n> Saving data to {output}") if lag_end < 0: data = data[:lag_end] if lag_start < 0: @@ -838,9 +839,9 @@ def sort_spikes(self, mc_method="dredge"): # check if already sorted and exported sa_dir = self.processed / stream_files["sorting_analyser"] if not sa_dir.exists(): - logging.info(f"> {self.name} {stream_id} not sorted/exported.\n") + logging.info(f"\n> {self.name} {stream_id} not sorted/exported.") else: - logging.info("> Already sorted and exported, next session.\n") + logging.info("\n> Already sorted and exported, next session.") continue # get catgt directory @@ -928,7 +929,7 @@ def configure_motion_tracking(self, project: str) -> None: copy_videos=False, ) else: - logging.warning(f"Config not found.") + logging.warning(f"\nConfig not found.") reply = input("Create new project? [Y/n]") if reply and reply[0].lower() == "n": raise PixelsError("A DLC project is needed for motion tracking.") @@ -1560,7 +1561,7 @@ def get_spike_times(self, units, remapped=False, use_si=False): ) repeats = c_times[np.where(counts>1)] if len(repeats>1): - logging.info(f"> removed {len(repeats)} double-counted " + logging.info(f"\n> removed {len(repeats)} double-counted " "spikes from cluster {c}.") by_clust[c] = pd.Series(uniques) @@ -1626,7 +1627,7 @@ def _get_aligned_spike_times( if len(centre) == 0: # See comment in align_trials as to why we just continue instead of # erroring like we used to here. - logging.info("No event found for an action. If this is OK, " + logging.info("\nNo event found for an action. If this is OK, " "ignore this.") continue centre = start + centre[0] @@ -1968,7 +1969,7 @@ def align_trials( raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") if data in ("spike_times", "spike_rate"): - logging.info(f"Aligning {data} to trials.") + logging.info(f"\nAligning {data} to trials.") # we let a dedicated function handle aligning spike times return self._get_aligned_spike_times( label, event, duration, rate=data == "spike_rate", sigma=sigma, @@ -1976,7 +1977,9 @@ def align_trials( ) if "trial" in data: - logging.info(f"Aligning {data} of {units} units to trials.") + logging.info( + f"\nAligning {data} of {units} units to label <{label}> trials." 
+ ) return self._get_aligned_trials( label, event, data=data, units=units, sigma=sigma, end_event=end_event, @@ -1988,14 +1991,14 @@ def align_trials( action_labels = self.get_action_labels() if raw: - logging.info(f"Aligning raw {data} data to trials.") + logging.info(f"\nAligning raw {data} data to trials.") getter = getattr(self, f"get_{data}_data_raw", None) if not getter: raise PixelsError(f"align_trials: {data} doesn't have a 'raw' option.") values, SAMPLE_RATE = getter() else: - logging.info(f"Aligning {data} data to trials.") + logging.info(f"\nAligning {data} data to trials.") if dlc_project: values = self.get_motion_tracking_data(dlc_project) elif data == "motion_index": @@ -2041,7 +2044,8 @@ def align_trials( # here to warn the user in case it is an error, while otherwise # continuing. #raise PixelsError('Action labels probably miscalculated') - logging.info("No event found for an action. If this is OK, ignore this.") + logging.info("\nNo event found for an action. If this is" + " OK, ignore this.") continue centre = start + centre[0] centre = int(centre * SAMPLE_RATE / self.SAMPLE_RATE) @@ -2136,7 +2140,7 @@ def align_clips(self, label, event, video_match, duration=1): for start in trial_starts: centre = np.where(np.bitwise_and(events[start:start + scan_duration], event))[0] if len(centre) == 0: - logging.info("No event found for an action. If this is OK, ignore this.") + logging.info("\nNo event found for an action. If this is OK, ignore this.") continue centre = start + centre[0] frames = timings.loc[ @@ -2180,7 +2184,7 @@ def get_good_units_info(self): if self._good_unit_info is None: #az: good_units_info.tsv saved while running depth_profile.py info_file = self.interim / 'good_units_info.tsv' - #logging.info(f"> got good unit info at {info_file}\n") + #logging.info(f"\n> got good unit info at {info_file}\n") try: info = pd.read_csv(info_file, sep='\t') @@ -2198,7 +2202,7 @@ def get_spike_widths(self, units=None): all_widths = self.get_spike_widths() return all_widths.loc[all_widths.unit.isin(units)] - logging.info("Calculating spike widths") + logging.info("\nCalculating spike widths") waveforms = self.get_spike_waveforms() widths = [] @@ -2251,7 +2255,7 @@ def get_spike_waveforms(self, units=None, method='phy'): rec_forms = {} for u, unit in enumerate(units): - logging.info(f"{round(100 * u / len(units), 2)}% complete") + logging.info(f"\n{round(100 * u / len(units), 2)}% complete") # get the waveforms from only the best channel spike_ids = model.get_cluster_spikes(unit) best_chan = model.get_cluster_channels(unit)[0] @@ -2299,13 +2303,13 @@ def get_spike_waveforms(self, units=None, method='phy'): ks_mod_time = os.path.getmtime(self.ks_outputs / 'cluster_info.tsv') assert template_cache_mod_time < ks_mod_time check = True # re-extract waveforms - logging.info("> Re-extracting waveforms since kilosort output is newer.") + logging.info("\n> Re-extracting waveforms since kilosort output is newer.") except: if 'template_cache_mod_time' in locals(): - logging.info("> Loading existing waveforms.") + logging.info("\n> Loading existing waveforms.") check = False # load existing waveforms else: - logging.info("> Extracting waveforms since they are not extracted.") + logging.info("\n> Extracting waveforms since they are not extracted.") check = True # re-extract waveforms """ @@ -2368,7 +2372,7 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): # normalise these metrics before passing to k-means columns = ["unit", "duration", "trough_peak_ratio", 
"half_width", "repolarisation_slope", "recovery_slope"] - logging.info(f"> Calculating waveform metrics {columns[1:]}...\n") + logging.info(f"\n> Calculating waveform metrics {columns[1:]}...\n") waveforms = self.get_spike_waveforms() # remove nan values @@ -2418,7 +2422,7 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): # repolarisation slope returns = np.where(mean_waveform.iloc[trough_idx:] >= 0) + trough_idx if len(returns[0]) == 0: - logging.info(f"> The mean waveformrns never returned to baseline?\n") + logging.info(f"\n> The mean waveformrns never returned to baseline?\n") return_idx = mean_waveform.shape[0] - 1 else: return_idx = returns[0][0] @@ -2559,7 +2563,8 @@ def get_aligned_spike_rate_CI( for trial, t_start, t_end in zip(trials, start, end): if not (t_start < t_end): logging.warning( - f"Warning: trial {trial} skipped in CI calculation due to bad timepoints" + f"\nWarning: trial {trial} skipped in CI calculation" + " due to bad timepoints" ) continue trial_responses.append( From deef4d3ecca88accc79c297a778d8619c11b447b Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 May 2025 15:29:58 +0100 Subject: [PATCH 326/658] temporarily comment out cacheable funcs --- pixels/behaviours/base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b8e3272..b6f90b0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1900,7 +1900,7 @@ def get_lfp_data_raw(self): """ return self._get_neuro_raw('lfp') - @_cacheable + #@_cacheable def align_trials( self, label, event, units=None, data='spike_times', raw=False, duration=1, sigma=None, dlc_project=None, video_match=None, @@ -2194,7 +2194,7 @@ def get_good_units_info(self): self._good_unit_info = info return self._good_unit_info - @_cacheable + #@_cacheable def get_spike_widths(self, units=None): if units: # Always defer to getting widths for all units, so we only ever have to @@ -2225,7 +2225,7 @@ def get_spike_widths(self, units=None): df['median_ms'] = 1000 * df['median_ms'] / orig_rate return df - @_cacheable + #@_cacheable def get_spike_waveforms(self, units=None, method='phy'): """ Extracts waveforms of spikes. 
@@ -2346,7 +2346,7 @@ def get_spike_waveforms(self, units=None, method='phy'): not implemented!") - @_cacheable + #@_cacheable def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): """ This func is a work-around of spikeinterface's equivalent: @@ -2460,7 +2460,7 @@ def get_waveform_metrics(self, units=None, window=20, upsampling_factor=10): return df - @_cacheable + #@_cacheable def get_aligned_spike_rate_CI( self, label, event, start=0.000, step=0.100, end=1.000, @@ -2640,7 +2640,7 @@ def get_aligned_spike_rate_CI( return df - @_cacheable + #@_cacheable def get_positional_rate( self, label, event, end_event=None, sigma=None, units=None, ): From fa9c953b7a59e210f4571066888960c594e3f367 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 May 2025 15:31:16 +0100 Subject: [PATCH 327/658] formatting logging --- pixels/behaviours/virtual_reality.py | 12 +++---- pixels/decorators.py | 4 +-- pixels/ioutils.py | 2 +- pixels/pixels_utils.py | 52 ++++++++++++++++------------ pixels/stream.py | 8 ++--- 5 files changed, 43 insertions(+), 35 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 99ebcbc..34abf50 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -145,7 +145,7 @@ def _extract_action_labels(self, vr, vr_data): trial_dark = (vr_data.trial_type == Conditions.DARK) # <<<< definitions <<<< - logging.info(">> Mapping vr event times...") + logging.info("\n>> Mapping vr event times...") # >>>> gray >>>> # get timestamps of gray @@ -222,8 +222,8 @@ def _extract_action_labels(self, vr, vr_data): action_labels[trial_starts, 1] += Events.trial_start if not trial_starts.size == vr_data.trial_count.max(): - raise PixelsError(f"Number of trials does not equal to\ - \n{vr_data.trial_count.max()}.") + raise PixelsError(f"Number of trials does not equal to " + "{vr_data.trial_count.max()}.") # NOTE: if trial starts at 0, the first position_in_tunnel value will # NOT be nan @@ -263,7 +263,7 @@ def _extract_action_labels(self, vr, vr_data): # TODO jun 27 2024 positional events and valve events needs mapping - logging.info(">> Mapping vr action times...") + logging.info("\n>> Mapping vr action times...") # >>>> map reward types >>>> # get non-zero reward types @@ -299,8 +299,8 @@ def _extract_action_labels(self, vr, vr_data): assert (trial == vr_data.trial_count.unique().max()) assert (vr_data[of_trial].position_in_tunnel.max()\ < vr.tunnel_reset) - logging.info(f"> trial {trial} is unfinished when session ends, " - "so there is no outcome.") + logging.info(f"\n> trial {trial} is unfinished when session " + "ends, so there is no outcome.") # <<<< unfinished trial <<<< else: # >>>> non punished >>>> diff --git a/pixels/decorators.py b/pixels/decorators.py index 0e50e64..3dda450 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -46,7 +46,7 @@ def wrapper(*args, **kwargs): # load cache try: df = ioutils.read_hdf5(cache_path) - logging.info(f"> Cache loaded from {cache_path}.") + logging.info(f"\n> Cache loaded from {cache_path}.") except HDF5ExtError: df = None except (KeyError, ValueError): @@ -66,7 +66,7 @@ def wrapper(*args, **kwargs): # stream id is the first, data name is the second stream, name = parts[0], "/".join(parts[1:]) df.setdefault(stream, {})[name] = store[key] - logging.info(f"> Cache loaded from {cache_path}.") + logging.info(f"\n> Cache loaded from {cache_path}.") else: df = method(*args, **kwargs) cache_path.parent.mkdir(parents=True, exist_ok=True) diff 
--git a/pixels/ioutils.py b/pixels/ioutils.py index c93c315..200826f 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -414,7 +414,7 @@ def write_hdf5(path, df, key="df", mode="w", format="fixed"): complib="blosc:lz4hc", ) - print("HDF5 saved to ", path) + print("HDF5 saved to", path) return diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9936d69..c729614 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -31,7 +31,7 @@ def load_raw(paths, stream_id): recs = [] for p, path in enumerate(paths): # NOTE: if it is catgt data, pass directly `catgt_ap_data` - print(f"\n> Getting the orignial recording...") + logging.info(f"\n> Getting the orignial recording...") # load # recording # file rec = se.read_spikeglx( folder_path=path.parent, @@ -64,7 +64,7 @@ def preprocess_raw(rec, surface_depths): # split by groups groups = rec.split_by("group") for g, group in groups.items(): - print(f"> Preprocessing shank {g}") + logging.info(f"\n> Preprocessing shank {g}") # get brain surface depth of shank surface_depth = surface_depths[g] cleaned = _preprocess_raw(group, surface_depth) @@ -93,18 +93,18 @@ def _preprocess_raw(rec, surface_depth): Implementation of preprocessing on raw pixels data. """ # correct phase shift - print("\t> step 1: do phase shift correction.") + logging.info("\n\t> step 1: do phase shift correction.") rec_ps = spre.phase_shift(rec) # remove bad channels from sorting - print("\t> step 2: remove bad channels.") + logging.info("\n\t> step 2: remove bad channels.") bad_chan_ids, chan_labels = spre.detect_bad_channels( rec_ps, outside_channels_location="top", ) labels, counts = np.unique(chan_labels, return_counts=True) for label, count in zip(labels, counts): - print(f"\t\t> Found {count} channels labelled as {label}.") + logging.info(f"\n\t\t> Found {count} channels labelled as {label}.") rec_removed = rec_ps.remove_channels(bad_chan_ids) # get channel group id and use it to index into brain surface channel depth @@ -116,9 +116,9 @@ def _preprocess_raw(rec, surface_depth): # remove channels outside by using identified brain surface depths outside_chan_ids = chan_ids[chan_depths > surface_depth] rec_clean = rec_removed.remove_channels(outside_chan_ids) - print(f"\t\t> Removed {outside_chan_ids.size} outside channels.") + logging.info(f"\n\t\t> Removed {outside_chan_ids.size} outside channels.") - print("\t> step 3: do common median referencing.") + logging.info("\n\t> step 3: do common median referencing.") # NOTE: dtype will be converted to float32 during motion correction cmr = spre.common_reference( rec_clean, @@ -142,7 +142,7 @@ def correct_motion(rec, mc_method="dredge"): === None """ - print(f"\t> correct motion with {mc_method}.") + logging.info(f"\t> correct motion with {mc_method}.") # reduce spatial window size for four-shank estimate_motion_kwargs = { "win_step_um": 100, @@ -189,7 +189,7 @@ def detect_n_localise_peaks(rec, loc_method="monopolar_triangulation"): groups = rec.split_by("group") dfs = [] for g, group in groups.items(): - print(f"\n> Estimate drift of shank {g}") + logging.info(f"\n> Estimate drift of shank {g}") dfs.append(_detect_n_localise_peaks(group, loc_method)) # concat shanks df = pd.concat( @@ -219,7 +219,7 @@ def _detect_n_localise_peaks(rec, loc_method): from spikeinterface.sortingcomponents.peak_localization\ import localize_peaks - print("> step 1: detect peaks") + logging.info("\n> step 1: detect peaks") peaks = detect_peaks( recording=rec, method="by_channel", @@ -227,8 +227,10 @@ def 
_detect_n_localise_peaks(rec, loc_method): exclude_sweep_ms=0.2, ) - print("> step 2: localize the peaks to get a sense of their putative " - "depths") + logging.info( + "\n> step 2: localize the peaks to get a sense of their putative " + "depths" + ) peak_locations = localize_peaks( recording=rec, peaks=peaks, @@ -640,7 +642,7 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, return === """ - print(f"Processing repeat {i}...") + logging.info(f"\nProcessing repeat {i}...") # open readonly memmap spiked = init_memmap( path=concat_spiked_path, @@ -676,7 +678,7 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, chance_spiked.flush() chance_fr.flush() - print(f"Repeat {i} finished.") + logging.info(f"\nRepeat {i} finished.") return None @@ -691,7 +693,7 @@ def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, sample_rate, repeats, spiked, spiked_shape, concat_spiked_path) else: - print(f"> Spike chance already saved at {fr_df_path}, continue.") + logging.info(f"\n> Spike chance already saved at {fr_df_path}, continue.") return None @@ -776,8 +778,10 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, sigma, sample_rate, for future in concurrent.futures.as_completed(futures): future.result() else: - print("\n> Memmaps already created, only need to convert into " - "dataframes and save.") + logging.info( + "\n> Memmaps already created, only need to convert into " + "dataframes and save." + ) # convert it to dataframe and save it #save_chance( @@ -789,7 +793,7 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, sigma, sample_rate, # fr_df_path=fr_df_path, # d_shape=d_shape, #) - #print(f"\n> Chance data saved to {fr_df_path}.") + #logging.info(f"\n> Chance data saved to {fr_df_path}.") return None @@ -861,7 +865,7 @@ def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, === orig_idx: pandas """ - print(f"\n> Saving chance data...") + logging.info(f"\n> Saving chance data...") # get chance spiked df _convert_to_df( @@ -1063,11 +1067,15 @@ def correct_group_id(rec): # map bool channel x locations shank_bool = np.isin(x_locs, shank_x) if np.any(shank_bool) == False: - print(f"\n> Recording does not have shank {shank_id}, continue.") + logging.info( + f"\n> Recording does not have shank {shank_id}, continue." + ) continue group_ids[shank_bool] = shank_id - print("\n> Not all shanks used in multishank probe, change group ids into " - f"{np.unique(group_ids)}.") + logging.info( + "\n> Not all shanks used in multishank probe, change group ids into " + f"{np.unique(group_ids)}." + ) return group_ids diff --git a/pixels/stream.py b/pixels/stream.py index 2d7714e..8c0c3d5 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -49,7 +49,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): if "trial" in data: logging.info( - f"Aligning {data} of {units} units to <{label}> trials." + f"\n> Aligning {data} of {units} units to <{label}> trials." 
) return self._get_aligned_trials( label, event, data=data, units=units, sigma=sigma, @@ -94,7 +94,7 @@ def _get_aligned_trials( end_t = timestamps[selected_ends] if selected_starts.size == 0: - logging.info(f"> No trials found with label {label} and event " + logging.info(f"\n> No trials found with label {label} and event " f"{event}, output will be empty.") return None @@ -300,7 +300,7 @@ def _sync_vr(self, vr): try: synched_vr = file_utils.read_hdf5(synched_vr_path) - logging.info("> synchronised vr loaded") + logging.info("\n> synchronised vr loaded") except: # get sync pulses sync_map = ioutils.read_bin(spike_data, 385, 384) @@ -367,7 +367,7 @@ def _sync_vr(self, vr): events=action_labels[:, 1], timestamps=action_labels[:, 2], ) - logging.info(f"> Action labels saved to: {action_labels_path}.") + logging.info(f"\n> Action labels saved to: {action_labels_path}.") return None From 37463144fe5abfcbc05aaecb4a95e36477aada4f Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 May 2025 15:31:34 +0100 Subject: [PATCH 328/658] modify doc and make sure use the right attr --- pixels/stream.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 8c0c3d5..03eaeaa 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -272,16 +272,16 @@ def get_spike_times(self, units): def sync_vr(self, vr): - # get action labels + # get action labels & synched vr path action_labels = self.session.get_action_labels()[self.stream_num] synched_vr_path = self.session.find_file( self.behaviour_files['vr_synched'][self.stream_num], ) if action_labels and synched_vr_path: logging.info(f"\n> {self.stream_id} from {self.session.name} is " - "already synched with vr, now loaded.") + "already synched with vr.") else: - _sync_vr(self, vr) + self._sync_vr(vr) return None From aed28f40df3fc9a805740d526f094e022214fb5b Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 2 May 2025 15:32:07 +0100 Subject: [PATCH 329/658] check if data exists before indexing --- pixels/stream.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 03eaeaa..ccfd0da 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -429,6 +429,11 @@ def _bin_aligned_trials( end_event=end_event, ) + if trials is None: + logging.info(f"\n> No trials found with label {label} and event " + f"{event}, output will be empty.") + return None + logging.info( f"\n> Binning label <{label}> trials from {self.stream_id} " f"in {units}." @@ -439,11 +444,6 @@ def _bin_aligned_trials( spiked = trials["spiked"] positions = trials["positions"] - if spiked.size == 0: - logging.info(f"\n> No units found in {units}, continue.") - assert 0 - return None - # TODO apr 11 2025: # bin chance while bin data #spiked_chance_path = self.processed / stream_files["spiked_shuffled"] From 1419160c0bd66ac4d63aca2055766662229082d3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 5 May 2025 14:42:55 +0100 Subject: [PATCH 330/658] add documentation --- pixels/stream.py | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index ccfd0da..bf9d536 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -46,10 +46,41 @@ def load_raw_ap(self): @cacheable def align_trials(self, units, data, label, event, sigma, end_event): + """ + Align pixels data to behaviour trials. 
+ + params + === + units : list of lists of ints, optional + The output from self.select_units, used to only apply this method to a + selection of units. + + data : str, optional + The data type to align. + + label : int + An action label value to specify which trial types are desired. + + event : int + An event type value to specify which event to align the trials to. + + sigma : int, optional + Time in milliseconds of sigma of gaussian kernel to use when + aligning firing rates. + + end_event : int | None + For VR behaviour, when aligning to the whole trial, this param is + the end event to align to. + + return + === + df, output from individual functions according to data type. + """ - if "trial" in data: + if "trial_rate" in data: logging.info( - f"\n> Aligning {data} of {units} units to <{label}> trials." + f"\n> Aligning trial_times and {data} of {units} units to " + f"<{label}> trials." ) return self._get_aligned_trials( label, event, data=data, units=units, sigma=sigma, From dfb54864da51198ae3ea87a03ef811bd37e35cda Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 5 May 2025 14:43:22 +0100 Subject: [PATCH 331/658] implement cacheable positional data --- pixels/stream.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index bf9d536..0b7675d 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -525,6 +525,34 @@ def _bin_aligned_trials( return binned + @cacheable + def get_positional_data( + self, label, event, end_event=None, sigma=None, units=None, + ): + """ + Get positional firing rate of selected units in vr, and spatial + occupancy of each position. + """ + # NOTE: order of args matters for loading the cache! + # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded + + # get aligned firing rates and positions + trials = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="trial_rate", # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + + # get positional spike rate, spike count, and occupancy + positional_data = xut.get_vr_positional_data(trials) + + return positional_data + + def save_spike_chance(self, spiked, sigma): # TODO apr 21 2025: # do we put this func here or in stream.py?? From f50fccbb2cea7012487f88fd1cb3e207c4286b9a Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 5 May 2025 14:43:47 +0100 Subject: [PATCH 332/658] implement getting positional data at stream level --- pixels/behaviours/base.py | 114 ++++++++------------------------------ 1 file changed, 22 insertions(+), 92 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b6f90b0..eb51b2b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2640,113 +2640,43 @@ def get_aligned_spike_rate_CI( return df - #@_cacheable - def get_positional_rate( + def get_positional_data( self, label, event, end_event=None, sigma=None, units=None, ): """ Get positional firing rate of selected units in vr, and spatial occupancy of each position. """ - # get constants from vd - from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END - # NOTE: order of args matters for loading the cache! 
# always put units first, cuz it is like that in # experiemnt.align_trials, otherwise the same cache cannot be loaded - # get aligned firing rates and positions - trials = self.align_trials( - units=units, # NOTE: ALWAYS the first arg - data="trial_rate", # NOTE: ALWAYS the second arg - label=label, - event=event, - sigma=sigma, - end_event=end_event, - ) - fr = trials["fr"] - positions = trials["positions"] - - # get unit_ids - unit_ids = fr.columns.get_level_values("unit").unique() - - # create position indices - indices = np.arange(0, TUNNEL_RESET+2) - # create occupancy array for trials - occupancy = np.full( - (TUNNEL_RESET+2, positions.shape[1]), - np.nan, - ) - # create array for positional firing rate - pos_fr = {} - - for t, trial in enumerate(positions): - # get trial position - trial_pos = positions[trial].dropna() - - # floor pre reward zone and end ceil post zone end - trial_pos = trial_pos.apply( - lambda x: np.floor(x) if x <= ZONE_END else np.ceil(x) + output = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, ) - # set to int - trial_pos = trial_pos.astype(int) - - # exclude positions after tunnel reset - trial_pos = trial_pos[trial_pos <= TUNNEL_RESET+1] - # get firing rates for current trial of all units - trial_fr = fr.xs( - key=trial, - axis=1, - level="trial", - ).dropna(how="all").copy() - - # get all indices before post reset - no_post_reset = trial_fr.index.intersection(trial_pos.index) - # remove post reset rows - trial_fr = trial_fr.loc[no_post_reset] - trial_pos = trial_pos.loc[no_post_reset] - - # put trial positions in trial fr df - trial_fr["position"] = trial_pos.values - # group values by position and get mean - mean_fr = trial_fr.groupby("position")[unit_ids].mean() - # reindex into full tunnel length - pos_fr[trial] = mean_fr.reindex(indices) - # get trial occupancy - pos_count = trial_fr.groupby("position").size() - occupancy[pos_count.index.values, t] = pos_count.values - - # concatenate dfs - pos_fr = pd.concat(pos_fr, axis=1, names=["trial", "unit"]) - # convert to df - occupancy = pd.DataFrame( - data=occupancy, - index=indices, - columns=positions.columns, - ) - - # add another level of starting position - # Get the starting index for each trial (column) - starts = occupancy.apply(lambda col: col.first_valid_index()) - # Group trials by their starting index - trial_level = pos_fr.columns.get_level_values("trial") - unit_level = pos_fr.columns.get_level_values("unit") - # map start level - start_level = trial_level.map(starts) - # define new columns - new_cols = pd.MultiIndex.from_arrays( - [start_level, unit_level, trial_level], - names=["start", "unit", "trial"], - ) - pos_fr.columns = new_cols - # sort by unit - pos_fr = pos_fr.sort_index(level="unit", axis=1) + logging.info( + f"\n> Getting positional neural data of {units} units in " + f"<{label}> trials." + ) + output[stream_id] = stream.get_positional_data( + units=units, # NOTE: put units first! 
+ label=label, + event=event, + end_event=end_event, + sigma=sigma, + ) - return {"pos_fr": pos_fr, "occupancy": occupancy} + return output - def bin_aligned_trials( + def get_binned_trials( self, label, event, units=None, sigma=None, end_event=None, time_bin=None, pos_bin=None, ): From a1b259c6c299a8c39f72fa08dba5aae2b5eeaea3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 5 May 2025 14:44:36 +0100 Subject: [PATCH 333/658] get vr positional neural data --- pixels/pixels_utils.py | 142 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 141 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c729614..23624ba 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -21,7 +21,7 @@ from pixels.error import PixelsError from pixels.configs import * -from common_utils.math_utils import random_sampling +from common_utils.math_utils import random_sampling, group_and_aggregate from common_utils.file_utils import init_memmap, read_hdf5 def load_raw(paths, stream_id): @@ -1079,3 +1079,143 @@ def correct_group_id(rec): ) return group_ids + + +def get_vr_positional_data(trial_data): + """ + Get positional firing rate and spike count for VR behaviour. + + params + === + trial_data: pandas df, output from align_trials. + + return + === + dict, positional firing rate, positional spike count, positional occupancy, + data in 1cm resolution. + """ + pos_fr, occupancy = _get_vr_positional_neural_data( + positions=trial_data["positions"], + data_type="spike_rate", + data=trial_data["fr"], + ) + pos_fc, _ = _get_vr_positional_neural_data( + positions=trial_data["positions"], + data_type="spiked", + data=trial_data["spiked"], + ) + + return {"pos_fr": pos_fr, "pos_fc": pos_fc, "occupancy": occupancy} + + +def _get_vr_positional_neural_data(positions, data_type, data): + """ + Get positional neural data for VR behaviour. + + params + === + positions: pandas df, vr positions of all trials. + shape: time x trials. + + data_type: str, type of neural data. + "spike_rate": firing rate of each unit in each trial. + "spiked": spike boolean of each unit in each trial. + + data: pandas df, aligned trial firing rate or spike boolean. + shape: time x (unit x trial) + levels: unit, trial + + return + === + pos_data: pandas df, positional neural data. + shape: position x (num of starting positions x unit x trial) + levels: start, unit, trial + + occupancy: pandas df, count of each position. 
+ shape: position x trial + """ + logging.info(f"\n> Getting positional {data_type}...") + + # get constants from vd + from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END + + # get the starting index for each trial (column) + starts = positions.iloc[0, :].astype(int) + # create position indices + indices = np.arange(0, TUNNEL_RESET+2) + # create occupancy array for trials + occupancy = np.full( + (TUNNEL_RESET+2, positions.shape[1]), + np.nan, + ) + + pos_data = {} + for t, trial in enumerate(positions): + # get trial position + trial_pos = positions[trial].dropna() + + # floor pre reward zone and end ceil post zone end + trial_pos = trial_pos.apply( + lambda x: np.floor(x) if x <= ZONE_END else np.ceil(x) + ) + # set to int + trial_pos = trial_pos.astype(int) + + # exclude positions after tunnel reset + trial_pos = trial_pos[trial_pos <= TUNNEL_RESET+1] + + # get firing rates for current trial of all units + trial_data = data.xs( + key=trial, + axis=1, + level="trial", + ).dropna(how="all").copy() + + # get all indices before post reset + no_post_reset = trial_data.index.intersection(trial_pos.index) + # remove post reset rows + trial_data = trial_data.loc[no_post_reset] + trial_pos = trial_pos.loc[no_post_reset] + + # put trial positions in trial data df + trial_data["position"] = trial_pos.values + + if data_type == "spike_rate": + # group values by position and get mean data + how = "mean" + elif data_type == "spiked": + # group values by position and get sum data + how = "sum" + grouped_data = group_and_aggregate(trial_data, "position", "sum") + + # reindex into full tunnel length + pos_data[trial] = grouped_data.reindex(indices) + # get trial occupancy + pos_count = trial_data.groupby("position").size() + occupancy[pos_count.index.values, t] = pos_count.values + + # concatenate dfs + pos_data = pd.concat(pos_data, axis=1, names=["trial", "unit"]) + # convert to df + occupancy = pd.DataFrame( + data=occupancy, + index=indices, + columns=positions.columns, + ) + + # add another level of starting position + # group trials by their starting index + trial_level = pos_data.columns.get_level_values("trial") + unit_level = pos_data.columns.get_level_values("unit") + # map start level + start_level = trial_level.map(starts) + # define new columns + new_cols = pd.MultiIndex.from_arrays( + [start_level, unit_level, trial_level], + names=["start", "unit", "trial"], + ) + pos_data.columns = new_cols + # sort by unit + pos_data = pos_data.sort_index(level="unit", axis=1) + + return pos_data, occupancy From c3564a3a43d28a6c89ef09f76876bd7b92497007 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:02:13 +0100 Subject: [PATCH 334/658] allows multiple sessions on the same day --- pixels/ioutils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 200826f..4b0d546 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -465,11 +465,20 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt, of_date=None): ]) if of_date is not None: - date_struct = datetime.datetime.strptime(of_date, session_date_fmt) - mouse_sessions = [mouse_sessions[session_dates.index(date_struct)]] - print(f"\n> Getting 1 session from {mouse} of " + if isinstance(of_date, str): + date_list = [of_date] + else: + date_list = of_date + + date_sessions = [] + for date in date_list: + date_struct = datetime.datetime.strptime(date, session_date_fmt) + 
date_sessions.append(mouse_sessions[session_dates.index(date_struct)]) + logging.info( + f"\n> Getting one session from {mouse} on " f"{datetime.datetime.strftime(date_struct, '%Y %B %d')}." - ) + ) + mouse_sessions = date_sessions if not meta_dir: # Do not collect metadata From eff5493a4a7fbdaff863c450b205e5311de7277b Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:27:09 +0100 Subject: [PATCH 335/658] use configs --- pixels/ioutils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 4b0d546..7e99441 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -17,6 +17,7 @@ from nptdms import TdmsFile from pixels.error import PixelsError +from pixels.configs import * def get_data_files(data_dir, session_name): From 145438c52e9a17264dd3e5e25e19da9d5072ec08 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:27:32 +0100 Subject: [PATCH 336/658] get algined trials across sessions --- pixels/ioutils.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 7e99441..0a46d34 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -845,3 +845,47 @@ def save_cols_to_frame(df, path): ) # NOTE: to reconstruct: # df.columns = col_df.values + + +def get_aligned_data_across_sessions(trials, key, level_names): + """ + Get aligned trials across sessions. + + params + === + trials: nested dict, aligned trials from multiple sessions. + keys: session_name -> stream_id -> "fr", "positions", "spiked" + + key: str, type of data to get. + "fr": firing rate, longest trial time x (unit x trial) + "spiked": spiked boolean, longest trial time x (unit x trial) + "positions": trial positions, longest trial time x trial + + return + === + df: pandas df, concatenated key data across sessions. + """ + per_session = {} + for s_name, s_data in trials.items(): + key_data = {} + for stream_id, stream_data in s_data.items(): + key_data[stream_id] = stream_data[key] + + # concat at stream level + per_session[s_name] = pd.concat( + key_data, + axis=1, + names=level_names[1:], + ) + + # concat at session level + df = pd.concat( + per_session, + axis=1, + names=level_names, + ) + + # swap stream and session so that stream is the most outer level + output = df.swaplevel("session", "stream", axis=1) + + return output From e8cfab7000008add51476ce31e9e14451ed0f1af Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:28:19 +0100 Subject: [PATCH 337/658] use dict not list so that session name can be the key --- pixels/experiment.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 03282e1..a1db682 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -221,10 +221,10 @@ def select_units(self, *args, **kwargs): Select units based on specified criteria. The output of this can be passed to some other methods to apply those methods only to these units. 
""" - units = [] + units = {} for i, session in enumerate(self.sessions): - units.append(session.select_units(*args, **kwargs)) + units[session.name] = session.select_units(*args, **kwargs) return units @@ -235,14 +235,19 @@ def align_trials(self, *args, units=None, **kwargs): """ trials = {} for i, session in enumerate(self.sessions): + name = session.name result = None if units: - if units[i]: - result = session.align_trials(*args, units=units[i], **kwargs) + if units[name]: + result = session.align_trials( + *args, + units=units[name], + **kwargs, + ) else: result = session.align_trials(*args, **kwargs) if result is not None: - trials[i] = result + trials[name] = result if "motion_tracking" in args: df = pd.concat( From 425e12be992b7c5dd25c77df22ef0139d2ad84ae Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:28:56 +0100 Subject: [PATCH 338/658] get different aligned data across sessions --- pixels/experiment.py | 139 +++++++++++++++++++++++++++---------------- 1 file changed, 89 insertions(+), 50 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index a1db682..138166a 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -256,32 +256,27 @@ def align_trials(self, *args, units=None, **kwargs): names=["session", "trial", "scorer", "bodyparts", "coords"] ) - # TODO apr 3 2025: - # make sure trial_times is here too if "trial_rate" in kwargs.values(): - frs = {} - positions = {} - for s in trials: - frs[s] = trials[s]["fr"] - positions[s] = trials[s]["positions"] - - frs_df = pd.concat( - frs.values(), - axis=1, - copy=False, - keys=frs.keys(), - names=["session"] + level_names = ["session", "stream", "unit", "trial"] + fr = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="fr", + level_names=level_names, ) - pos_df = pd.concat( - positions.values(), - axis=1, - copy=False, - keys=positions.keys(), - names=["session"] + spiked = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="spiked", + level_names=level_names, + ) + positions = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="positions", + level_names=["session", "stream", "trial"], ) df = { - "fr": frs_df, - "positions": pos_df, + "fr": fr, + "spiked": spiked, + "positions": positions, } else: df = pd.concat( @@ -489,50 +484,94 @@ def get_session_by_name(self, name: str): raise PixelsError - def get_positional_rate(self, *args, units=None, **kwargs): + def get_positional_data(self, *args, units=None, **kwargs): """ Get positional firing rate for aligned vr trials. - Check behaviours.base.Behaviour.align_trials for usage information. + Check behaviours.base.Behaviour.get_positional_data for usage + information. 
""" trials = {} for i, session in enumerate(self.sessions): + name = session.name result = None if units: - if units[i]: - result = session.get_positional_rate( + if units[name]: + result = session.get_positional_data( *args, - units=units[i], + units=units[name], **kwargs, ) else: - result = session.get_positional_rate(*args, **kwargs) + result = session.get_positional_data(*args, **kwargs) if result is not None: - trials[i] = result - - pos_frs = {} - occupancies = {} - for s in trials: - pos_frs[s] = trials[s]["pos_fr"] - occupancies[s] = trials[s]["occupancy"] - - pos_frs_df = pd.concat( - pos_frs.values(), - axis=1, - copy=False, - keys=pos_frs.keys(), - names=["session"] + trials[name] = result + + level_names = ["session", "stream", "start", "unit", "trial"] + pos_fr = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="pos_fr", + level_names=level_names, ) - occu_df = pd.concat( - occupancies.values(), - axis=1, - copy=False, - keys=occupancies.keys(), - names=["session"] + pos_fc = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="pos_fc", + level_names=level_names, + ) + occupancies = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="occupancy", + level_names=["session", "stream", "trial"], ) df = { - "pos_fr": pos_frs_df, - "occupancy": occu_df, + "pos_fr": pos_fr, + "pos_fc": pos_fc, + "occupancy": occupancies, } return df + + def get_binned_trials(self, *args, units=None, **kwargs): + """ + Get binned firing rate and spike count for aligned vr trials. + Check behaviours.base.Behaviour.get_binned_trials for usage information. + """ + trials = {} + for i, session in enumerate(self.sessions): + name = session.name + result = None + if units: + if units[name]: + result = session.get_binned_trials( + *args, + units=units[name], + **kwargs, + ) + else: + result = session.get_binned_trials(*args, **kwargs) + if result is not None: + trials[name] = result + + level_names = ["session", "stream", "unit", "trial"] + bin_fr = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="bin_fr", + level_names=level_names, + ) + bin_fc = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="bin_fc", + level_names=level_names, + ) + bin_pos = ioutils.get_aligned_data_across_sessions( + trials=trials, + key="bin_pos", + level_names=["session", "stream", "pos_type", "trial"], + ) + df = { + "bin_fr": bin_fr, + "bin_fc": bin_fc, + "bin_pos": bin_pos, + } + + return df From dd6f20386b46dfd0a43cb5906b760603b5b2b6ba Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:29:26 +0100 Subject: [PATCH 339/658] add logging info --- pixels/behaviours/base.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index eb51b2b..d07d7ac 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -794,6 +794,10 @@ def load_raw_ap(self): files=stream_files, session=self, ) + + logging.info( + f"\n> Loading raw {stream_id} data." + ) raw_rec = stream.load_raw_ap() # now the value for streams dict is recording extractor @@ -1978,7 +1982,8 @@ def align_trials( if "trial" in data: logging.info( - f"\nAligning {data} of {units} units to label <{label}> trials." + f"\n> Aligning {self.name} {data} of {units} units to label " + f"<{label}> trials." 
) return self._get_aligned_trials( label, event, data=data, units=units, sigma=sigma, @@ -2703,6 +2708,11 @@ def get_binned_trials( files=stream_files, session=self, ) + + logging.info( + f"\n> Getting binned <{label}> trials from {stream_id} " + f"in {units}." + ) binned[stream_id] = stream.get_binned_trials( label=label, event=event, @@ -2823,6 +2833,10 @@ def sync_vr(self, vr): files=stream_files, session=self, ) + + logging.info( + f"\n> Synchonising pixels data with vr." + ) stream.sync_vr(vr) return None From c5cdc2457429ec150c163155c52c5b59c2e7ab67 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:29:36 +0100 Subject: [PATCH 340/658] typo --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d07d7ac..746e119 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1698,7 +1698,7 @@ def _get_aligned_trials( files=stream_files, session=self, ) - output[stream_id] = stream.get_aligned_trials( + output[stream_id] = stream.align_trials( units=units, # NOTE: ALWAYS the first arg data="trial_rate", # NOTE: ALWAYS the second arg label=label, From 7060f5be40ba7a1ccb307bbadebda00cc57b7d2e Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:30:05 +0100 Subject: [PATCH 341/658] more clear logging info --- pixels/stream.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 0b7675d..868a07e 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -37,6 +37,7 @@ def __init__( def __repr__(self): return f"" + def load_raw_ap(self): paths = [self.session.find_file(path) for path in self.files["ap_raw"]] self.files["si_rec"] = xut.load_raw(paths, self.stream_id) @@ -79,7 +80,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): if "trial_rate" in data: logging.info( - f"\n> Aligning trial_times and {data} of {units} units to " + f"\n> Aligning spike times and spike rate of {units} units to " f"<{label}> trials." ) return self._get_aligned_trials( @@ -466,7 +467,7 @@ def _bin_aligned_trials( return None logging.info( - f"\n> Binning label <{label}> trials from {self.stream_id} " + f"\n> Binning <{label}> trials from {self.stream_id} " f"in {units}." 
) From 7d209b2712385709d85f251cd3e1779df856d362 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:30:18 +0100 Subject: [PATCH 342/658] cache binned trials --- pixels/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/stream.py b/pixels/stream.py index 868a07e..a7b44ad 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -418,6 +418,7 @@ def get_synched_vr(self): return synched_vr, action_labels + @cacheable def get_binned_trials( self, label, event, units=None, sigma=None, end_event=None, time_bin=None, pos_bin=None From 67123b8c36090386420a4e9240a90f39ea174cde Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:30:42 +0100 Subject: [PATCH 343/658] explicitly name arrays and return df format --- pixels/stream.py | 43 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index a7b44ad..94f3d83 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -483,7 +483,7 @@ def _bin_aligned_trials( #spiked_chance = ioutils.read_hdf5(spiked_chance_path, "spiked") #bin_counts_chance[stream_id] = {} - binned = {} + bin_arr = {} binned_count = {} binned_fr = {} for trial in positions.columns.unique(): @@ -512,19 +512,46 @@ def _bin_aligned_trials( # stack df values into np array # reshape into trials x units x bins - binned_count = ioutils.reindex_by_longest(binned_count).T - binned_fr = ioutils.reindex_by_longest(binned_fr).T + count_arr = ioutils.reindex_by_longest(binned_count).T + fr_arr = ioutils.reindex_by_longest(binned_fr).T # save bin_fr and bin_count, for andrew # use label as array key name - binned["count"] = binned_count[:, :-2, :] - binned["fr"] = binned_fr[:, :-2, :] - binned["pos"] = binned_count[:, -2:, :] + bin_arr["count"] = count_arr[:, :-2, :] + bin_arr["fr"] = fr_arr[:, :-2, :] + bin_arr["pos"] = count_arr[:, -2:, :] - np.savez_compressed(output_path, **binned) + np.savez_compressed(output_path, **bin_arr) logging.info(f"\n> Output saved at {output_path}.") - return binned + # extract binned data in df format + bin_fc, bin_pos = self._extract_binned_data(binned_count) + bin_fr, _ = self._extract_binned_data(binned_fr) + + return {"bin_fc": bin_fc, "bin_fr": bin_fr, "bin_pos": bin_pos} + + + def _extract_binned_data(self, binned_data): + """ + """ + df = ioutils.reindex_by_longest( + dfs=binned_data, + idx_names=["trial", "time_bin"], + col_names=["unit"], + return_format="dataframe", + ) + data = df.drop( + labels=["positions", "bin_pos"], + axis=1, + level="unit", + ) + pos = df.filter( + like="pos", + axis=1, + ) + pos.columns.names = ["pos_type", "trial"] + + return data, pos @cacheable From e7d74be3a1adb0f6ee36dac8cc1bd7dfa11d6594 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:44:36 +0100 Subject: [PATCH 344/658] use print for within attr logs --- pixels/pixels_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 23624ba..d8572f2 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -93,18 +93,18 @@ def _preprocess_raw(rec, surface_depth): Implementation of preprocessing on raw pixels data. 
""" # correct phase shift - logging.info("\n\t> step 1: do phase shift correction.") + print("\n\t> step 1: do phase shift correction.") rec_ps = spre.phase_shift(rec) # remove bad channels from sorting - logging.info("\n\t> step 2: remove bad channels.") + print("\n\t> step 2: remove bad channels.") bad_chan_ids, chan_labels = spre.detect_bad_channels( rec_ps, outside_channels_location="top", ) labels, counts = np.unique(chan_labels, return_counts=True) for label, count in zip(labels, counts): - logging.info(f"\n\t\t> Found {count} channels labelled as {label}.") + print(f"\n\t\t> Found {count} channels labelled as {label}.") rec_removed = rec_ps.remove_channels(bad_chan_ids) # get channel group id and use it to index into brain surface channel depth @@ -116,9 +116,9 @@ def _preprocess_raw(rec, surface_depth): # remove channels outside by using identified brain surface depths outside_chan_ids = chan_ids[chan_depths > surface_depth] rec_clean = rec_removed.remove_channels(outside_chan_ids) - logging.info(f"\n\t\t> Removed {outside_chan_ids.size} outside channels.") + print(f"\n\t\t> Removed {outside_chan_ids.size} outside channels.") - logging.info("\n\t> step 3: do common median referencing.") + print("\n\t> step 3: do common median referencing.") # NOTE: dtype will be converted to float32 during motion correction cmr = spre.common_reference( rec_clean, From 22398dc0b65cd22f806f4ce344b80910fe4be440 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 6 May 2025 15:45:05 +0100 Subject: [PATCH 345/658] make sure to start on a new line --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d8572f2..5b96d41 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -142,7 +142,7 @@ def correct_motion(rec, mc_method="dredge"): === None """ - logging.info(f"\t> correct motion with {mc_method}.") + logging.info(f"\n\t> correct motion with {mc_method}.") # reduce spatial window size for four-shank estimate_motion_kwargs = { "win_step_um": 100, From ee700ec8ce8a0349318a6d9f46f7e8e1ef45e83b Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 8 May 2025 13:15:19 +0100 Subject: [PATCH 346/658] lfp band till 500Hz --- pixels/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/constants.py b/pixels/constants.py index 8f52d89..9ee37c3 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -7,7 +7,7 @@ freq_bands = { "ap":[300, 9000], - "lfp":[0.5, 300], + "lfp":[0.5, 500], } BEHAVIOUR_HZ = 25000 From 421174a7c039a0fdebb432d8a25bb1b8c6ee047c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 8 May 2025 13:16:02 +0100 Subject: [PATCH 347/658] formatting --- pixels/pixels_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 5b96d41..157c45d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -93,18 +93,18 @@ def _preprocess_raw(rec, surface_depth): Implementation of preprocessing on raw pixels data. 
""" # correct phase shift - print("\n\t> step 1: do phase shift correction.") + print("\t> step 1: do phase shift correction.") rec_ps = spre.phase_shift(rec) # remove bad channels from sorting - print("\n\t> step 2: remove bad channels.") + print("\t> step 2: remove bad channels.") bad_chan_ids, chan_labels = spre.detect_bad_channels( rec_ps, outside_channels_location="top", ) labels, counts = np.unique(chan_labels, return_counts=True) for label, count in zip(labels, counts): - print(f"\n\t\t> Found {count} channels labelled as {label}.") + print(f"\t\t> Found {count} channels labelled as {label}.") rec_removed = rec_ps.remove_channels(bad_chan_ids) # get channel group id and use it to index into brain surface channel depth @@ -116,9 +116,9 @@ def _preprocess_raw(rec, surface_depth): # remove channels outside by using identified brain surface depths outside_chan_ids = chan_ids[chan_depths > surface_depth] rec_clean = rec_removed.remove_channels(outside_chan_ids) - print(f"\n\t\t> Removed {outside_chan_ids.size} outside channels.") + print(f"\t\t> Removed {outside_chan_ids.size} outside channels.") - print("\n\t> step 3: do common median referencing.") + print("\t> step 3: do common median referencing.") # NOTE: dtype will be converted to float32 during motion correction cmr = spre.common_reference( rec_clean, From 64a8c520894a6bf42a9d0f6b45d880410f4a6ffd Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:38:36 +0100 Subject: [PATCH 348/658] inherit more dirs from session --- pixels/stream.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 94f3d83..5b93c0a 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -29,8 +29,9 @@ def __init__( self.BEHAVIOUR_SAMPLE_RATE = session.SAMPLE_RATE self.raw = session.raw self.interim = session.interim - self.processed = session.processed self.cache = self.interim / "cache/" + self.processed = session.processed + self.histology = session.histology self._use_cache = True From e9982e733640fb4abb42cae387cce1ec66f7d0b1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:38:51 +0100 Subject: [PATCH 349/658] also import constants --- pixels/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/stream.py b/pixels/stream.py index 5b93c0a..b3b164f 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -7,6 +7,7 @@ from pixels import pixels_utils as xut import pixels.signal_utils as signal from pixels.configs import * +from pixels.constants import * from pixels.decorators import cacheable from common_utils import file_utils From d10d515331faecc94c1cf5e3a0ce2d0b0a2a1bcc Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:40:51 +0100 Subject: [PATCH 350/658] cache binned trials hence no need to check for existence --- pixels/stream.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index b3b164f..364f430 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -428,24 +428,16 @@ def get_binned_trials( # define output path for binned spike rate output_path = self.cache/ f"{self.session.name}_{label}_{units}_"\ f"{time_bin}_{pos_bin}cm_{self.stream_id}.npz" - - if output_path.exists(): - binned = np.load(output_path) - logging.info( - f"\n> <{label}> trials from {self.stream_id} in {units} " - "already binned, now loaded." 
- ) - else: - binned = self._bin_aligned_trials( - label=label, - event=event, - units=units, - sigma=sigma, - end_event=end_event, - time_bin=time_bin, - pos_bin=pos_bin, - output_path=output_path, - ) + binned = self._bin_aligned_trials( + label=label, + event=event, + units=units, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + output_path=output_path, + ) return binned From 32f0a449c970acb210856ecba2b71be6ab51ca25 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:41:41 +0100 Subject: [PATCH 351/658] implement stream-level preprocessing, band extraction, ap motion correction, whitening, and spike sorting --- pixels/stream.py | 112 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 364f430..22753d4 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -576,6 +576,118 @@ def get_positional_data( return positional_data + def preprocess_raw(self): + # load raw ap + raw_rec = self.load_raw_ap() + + # load brain surface depths + depth_info = file_utils.load_yaml( + path=self.histology / self.files["depth_info"], + ) + surface_depths = depth_info["raw_signal_depths"][self.stream_id] + + # preprocess + self.files["preprocessed"] = xut.preprocess_raw( + raw_rec, + surface_depths, + ) + + return None + + + def extract_bands(self, freqs): + self.preprocess_raw() + + if freqs == None: + bands = freq_bands + elif isinstance(freqs, str) and freqs in freq_bands.keys(): + bands = {freqs: freq_bands[freqs]} + elif isinstance(freqs, dict): + bands = freqs + + for name, freqs in bands.items(): + logging.info( + f"\n> Extracting {name} bands from {self.stream_id}." + ) + # do bandpass filtering + self.files[f"{name}_extracted"] = xut.extract_band( + self.files["preprocessed"], + freq_min=freqs[0], + freq_max=freqs[1], + ) + + return None + + + def correct_ap_motion(self): + # get ap band + self.extract_bands("ap") + ap_rec = self.files["ap_extracted"] + + # correct ap motion + self.files["ap_motion_corrected"] = xut.correct_ap_motion(ap_rec) + + return None + + + def correct_lfp_motion(self): + raise NotImplementedError("> Not implemented.") + + + def whiten_ap(self): + # get motion corrected ap + mcd = self.files["ap_motion_corrected"] + + # whiten + self.files["ap_whitened"] = xut.whiten(mcd) + + return None + + + def sort_spikes(self, ks_mc, ks4_params, ks_image_path, output, sa_dir): + """ + Sort spikes of stream. + + params + === + ks_mc: bool, whether using kilosort 4 innate motion correction. + + ks4_params: dict, kilosort 4 parameters. + + output: path, directory to save sorting output. + + sa_dir: path, directory to save sorting analyser. + + return + === + + """ + # use only preprocessed if use ks motion correction + if ks_mc: + self.preprocess_raw() + rec = self.files["preprocessed"] + sa_rec = None + else: + # whiten ap band and feed to ks + rec = self.files["ap_whitened"] + # use non-whitened recording for sorting analyser + #sa_rec = self.files["ap_motion_corrected"] + sa_rec = rec + # TODO may 13 2025: test building sa with whitened + + # sort spikes and save sorting analyser to disk + xut.sort_spikes( + rec=rec, + sa_rec=sa_rec, + output=output, + curated_sa_dir=sa_dir, + ks_image_path=ks_image_path, + ks4_params=ks4_params, + ) + + return None + + def save_spike_chance(self, spiked, sigma): # TODO apr 21 2025: # do we put this func here or in stream.py?? 
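
Note on the stream-level pipeline added in PATCH 351 above: correct_ap_motion() already chains preprocess_raw() and extract_bands("ap") internally, so a caller only needs the three calls sketched below before sorting. This is a hypothetical usage sketch, not part of the patch series; the stream object, the paths, the image name, and the ks4_params shown are placeholders that base.py normally supplies.

    # Hypothetical helper (not part of the patches): run the stream-level chain
    # added in PATCH 351 for one probe. `stream` is assumed to be a
    # pixels.stream.Stream built by the session, as done in base.py.
    from pathlib import Path

    def sort_one_stream(stream, interim_dir: Path, processed_dir: Path) -> None:
        # correct_ap_motion() itself calls preprocess_raw() and
        # extract_bands("ap"), so those steps need not be invoked separately.
        stream.correct_ap_motion()
        stream.whiten_ap()
        stream.sort_spikes(
            ks_mc=False,                           # use dredge, not kilosort's own correction
            ks4_params={"do_CAR": False},          # placeholder; base.py builds the full dict
            ks_image_path=interim_dir / "ks4_image.sif",        # placeholder image name
            output=processed_dir / "sorted_stream_0",           # placeholder output dir
            sa_dir=processed_dir / "sorted_stream_0" / "curated_sa.zarr",
        )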
From adcdcfdede823006ed51104279687a83f5bf9dc5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:42:17 +0100 Subject: [PATCH 352/658] add different versions of kilosort 4 --- pixels/configs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pixels/configs.py b/pixels/configs.py index 6f00399..17dec3b 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -28,3 +28,8 @@ level=3, # high compression bps=None, # lossless ) + +# kilosort 4 singularity image names +ks4_0_30_image_name = "si102.3_ks4-0-30_with_wavpack.sif" +ks4_0_18_image_name = "ks4-0-18_with_wavpack.sif" +ks4_image_name = ks4_0_30_image_name From 5604b17fa5b8fe75dafc306a94ed1980900b61f2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:42:53 +0100 Subject: [PATCH 353/658] preprocessing done on the fly hence it'll always be si obj --- pixels/ioutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 0a46d34..4837fb6 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -47,7 +47,7 @@ def get_data_files(data_dir, session_name): "imec0":{ "ap_raw": [PosixPath("name.bin")], "ap_meta": [PosixPath("name.meta")], - "preprocessed": PosixPath("name.zarr"), + "preprocessed": spikeinterface recording obj, "ap_downsampled": PosixPath("name.zarr"), "lfp_downsampled": PosixPath("name.zarr"), "surface_depth": PosixPath("name.yaml"), From 5460ed7be2b95d55c670b0a4ffb977aac9d28657 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:47:26 +0100 Subject: [PATCH 354/658] only save motion corrected rec, ap & lfp separately, preprocess & whiten on the fly --- pixels/ioutils.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 4837fb6..49a5e53 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -48,8 +48,9 @@ def get_data_files(data_dir, session_name): "ap_raw": [PosixPath("name.bin")], "ap_meta": [PosixPath("name.meta")], "preprocessed": spikeinterface recording obj, - "ap_downsampled": PosixPath("name.zarr"), - "lfp_downsampled": PosixPath("name.zarr"), + "ap_extracted": spikeinterface recording obj, + "ap_whitened": spikeinterface recording obj, + "lfp_extracted": spikeinterface recording obj, "surface_depth": PosixPath("name.yaml"), "sorting_analyser": PosixPath("name.zarr"), }, @@ -85,6 +86,9 @@ def get_data_files(data_dir, session_name): "ap_meta": [], "si_rec": None, # there could be only one, thus None "preprocessed": None, + "ap_extracted": None, + "ap_whitened": None, + "lfp_extracted": None, "CatGT_ap_data": [], "CatGT_ap_meta": [], } @@ -93,28 +97,25 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["ap_raw"].append(base_name) pixels[stream_id]["ap_meta"].append(original_name(ap_meta[r])) - # spikeinterface cache - pixels[stream_id]["motion_corrected"] = base_name.with_name( + # >>> spikeinterface cache >>> + # extracted & motion corrected ap stream, 300Hz+ + pixels[stream_id]["ap_motion_corrected"] = base_name.with_name( f"{base_name.stem}.mcd.zarr" ) + # extracted & motion corrected lfp stream, 500Hz- + pixels[stream_id]["lfp_motion_corrected"] = base_name.with_name( + f"{base_name.stem[:-3]}.lf.mcd.zarr" + ) pixels[stream_id]["detected_peaks"] = base_name.with_name( f"{base_name.stem}_detected_peaks.h5" ) pixels[stream_id]["sorting_analyser"] = base_name.parent/\ - f"sorted_stream_{stream_id[-4]}/curated_sa.zarr" - - # extracted ap stream, 300Hz+ - pixels[stream_id]["ap_extracted"] = 
base_name.with_name( - f"{base_name.stem}.extracted.zarr" - ) - # extracted lfp stream, 300Hz- - pixels[stream_id]["lfp_extracted"] = base_name.with_name( - f"{base_name.stem[:-3]}.lf.extracted.zarr" - ) + f"sorted_stream_{probe_id[-1]}/curated_sa.zarr" + # <<< spikeinterface cache <<< # depth info of probe pixels[stream_id]["surface_depth"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_surface_depth.yaml" + f"{session_name}_{probe_id}_surface_depth.yaml" ) pixels[stream_id]["clustered_channels"] = base_name.with_name( f"{session_name}_{stream_id}_channel_clustering_results.h5" From f3a0d92f914e72d928bf10743d93c155f49a845a Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:48:02 +0100 Subject: [PATCH 355/658] save probe_id as a var to make names more clear --- pixels/ioutils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 49a5e53..987cb07 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -79,6 +79,7 @@ def get_data_files(data_dir, session_name): pixels = {} for r, rec in enumerate(ap_raw): stream_id = rec[-12:-4] + probe_id = stream_id[:-3] # separate recordings by their stream ids if stream_id not in pixels: pixels[stream_id] = { @@ -129,31 +130,31 @@ def get_data_files(data_dir, session_name): # the chance # memmaps for temporary storage pixels[stream_id]["spiked_shuffled_memmap"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_spiked_shuffled.bin" + f"{session_name}_{probe_id}_spiked_shuffled.bin" ) pixels[stream_id]["fr_shuffled_memmap"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_fr_shuffled.bin" + f"{session_name}_{probe_id}_fr_shuffled.bin" ) pixels[stream_id]["shuffled_shape"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_shuffled_shape.json" + f"{session_name}_{probe_id}_shuffled_shape.json" ) pixels[stream_id]["shuffled_index"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_shuffled_index.h5" + f"{session_name}_{probe_id}_shuffled_index.h5" ) pixels[stream_id]["shuffled_columns"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_shuffled_columns.h5" + f"{session_name}_{probe_id}_shuffled_columns.h5" ) # .h5 files pixels[stream_id]["spiked_shuffled"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_spiked_shuffled.h5" + f"{session_name}_{probe_id}_spiked_shuffled.h5" ) pixels[stream_id]["fr_shuffled"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_fr_shuffled.h5" + f"{session_name}_{probe_id}_fr_shuffled.h5" ) # noise in curated units pixels[stream_id]["noisy_units"] = base_name.with_name( - f"{session_name}_{stream_id[:-3]}_noisy_units.yaml" + f"{session_name}_{probe_id}_noisy_units.yaml" ) # old catgt data From 3abf4fd3513e5708bf16eb09f7bf1997c15eb52f Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:50:19 +0100 Subject: [PATCH 356/658] separate motion correction for ap and lfp band --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 157c45d..0bdf48a 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -127,7 +127,7 @@ def _preprocess_raw(rec, surface_depth): return cmr -def correct_motion(rec, mc_method="dredge"): +def correct_ap_motion(rec, mc_method="dredge"): """ Correct motion of recording. 
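
Worked example of the naming scheme introduced in patches 354-355 above: stream_id is sliced out of the raw SpikeGLX file name and probe_id out of stream_id, and the cache names are derived from them. The file name below is made up for illustration; only the slicing and the derived names come from the code above.

    # Illustration only: an assumed SpikeGLX-style raw file name.
    rec = "example_session_g0_t0.imec0.ap.bin"
    stream_id = rec[-12:-4]      # -> "imec0.ap"
    probe_id = stream_id[:-3]    # -> "imec0"
    stem = rec[:-4]              # pathlib's base_name.stem for this file

    ap_mcd = f"{stem}.mcd.zarr"            # example_session_g0_t0.imec0.ap.mcd.zarr
    lfp_mcd = f"{stem[:-3]}.lf.mcd.zarr"   # example_session_g0_t0.imec0.lf.mcd.zarr
    sa_dir = f"sorted_stream_{probe_id[-1]}/curated_sa.zarr"   # sorted_stream_0/curated_sa.zarr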
From 4be7c02b676a4f6da0e453e031e6eb684fdc053a Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:51:56 +0100 Subject: [PATCH 357/658] explicitly define motion estimation method & interpolation dtype --- pixels/pixels_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 0bdf48a..724bc2a 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -142,18 +142,27 @@ def correct_ap_motion(rec, mc_method="dredge"): === None """ - logging.info(f"\n\t> correct motion with {mc_method}.") + logging.info(f"\n> Correcting motion with {mc_method}.") + # reduce spatial window size for four-shank + # TODO may 8 2025 "method":"dredge_ap" after it's implemented? estimate_motion_kwargs = { + "method": "decentralized", "win_step_um": 100, "win_margin_um": -150, + "verbose": True, + } + + # make sure recording dtype is float for interpolation + interpolate_motion_kwargs = { + "dtype": np.float32, } mcd = spre.correct_motion( rec, preset=mc_method, estimate_motion_kwargs=estimate_motion_kwargs, - #interpolate_motion_kwargs={'border_mode':'force_extrapolate'}, + interpolate_motion_kwargs=interpolate_motion_kwargs, ) # convert to int16 to save space From 4e5268b82c86210b08913fb9d3a87414cb408ac8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 12:52:48 +0100 Subject: [PATCH 358/658] explicitly define filter direction --- pixels/pixels_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 724bc2a..a91ad01 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -272,9 +272,9 @@ def extract_band(rec, freq_min, freq_max, ftype="butter"): ftype: str, filter type. since its posthoc, we use 5th order acausal filter, and takes - second-order sections (SOS) representation of the filter. but more - filters to choose from, e.g., bessel with filter_order=2, presumably - preserves waveform better? see lussac. + second-order sections (SOS) representation of the filter, + forward-backward. but more filters to choose from, e.g., bessel with + filter_order=2, presumably preserves waveform better? see lussac. return === @@ -284,7 +284,10 @@ def extract_band(rec, freq_min, freq_max, ftype="butter"): rec, freq_min=freq_min, freq_max=freq_max, + margin_ms=5.0, + filter_order=5, ftype=ftype, + direction="forward-backward", ) return band From 329622d26142acd3790eb2ba883a12c914feff4c Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:04:57 +0100 Subject: [PATCH 359/658] add func to whiten recording --- pixels/pixels_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index a91ad01..2405405 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -294,6 +294,17 @@ def extract_band(rec, freq_min, freq_max, ftype="butter"): def sort_spikes(rec, output, curated_sa_dir, ks_image_path, ks4_params): +def whiten(rec): + whitened = spre.whiten( + recording=rec, + dtype=np.float32, + mode="local", + radius_um=240.0, # 16 nearby chans in line with ks4 + ) + + return whitened + + """ Sort spikes with kilosort 4, curate sorting, save sorting analyser to disk, and export results to disk. 
From ce2d719f87718bec94142c0924c988931a2f868c Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:05:38 +0100 Subject: [PATCH 360/658] allows to define an additional rec obj for sorting analyser only --- pixels/pixels_utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 2405405..5995e46 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -293,7 +293,6 @@ def extract_band(rec, freq_min, freq_max, ftype="butter"): return band -def sort_spikes(rec, output, curated_sa_dir, ks_image_path, ks4_params): def whiten(rec): whitened = spre.whiten( recording=rec, @@ -305,6 +304,7 @@ def whiten(rec): return whitened +def sort_spikes(rec, sa_rec, output, curated_sa_dir, ks_image_path, ks4_params): """ Sort spikes with kilosort 4, curate sorting, save sorting analyser to disk, and export results to disk. @@ -313,6 +313,8 @@ def whiten(rec): === rec: spikeinterface recording object. + sa_rec: spikeinterface recording object for creating sorting analyser. + output: path object, directory of output. curated_sa_dir: path object, directory to save curated sorting analyser. @@ -338,6 +340,7 @@ def whiten(rec): # sort spikes sorting, recording = _sort_spikes( rec, + sa_rec, output, ks_image_path, ks4_params, @@ -361,14 +364,17 @@ def whiten(rec): return None -def _sort_spikes(rec, output, ks_image_path, ks4_params): +def _sort_spikes_by_group(rec, sa_rec, output, ks_image_path, ks4_params): """ - Sort spikes with kilosort 4. + Sort spikes with kilosort 4 by group/shank. params === rec: spikeinterface recording object. + sa_rec: spikeinterface recording object for creating sorting analyser. + if None, then use the temp_wh.dat from ks output. + output: path object, directory of output. ks_image_path: path object, directory of local kilosort 4 singularity image. From 19d4847be0dd973072f80635098442ffe8ac98d4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:15:14 +0100 Subject: [PATCH 361/658] add func to sort by gorup; use temp_wh for sa if no sa_rec is defined --- pixels/pixels_utils.py | 132 ++++++++++++++++++++++++++++++----------- 1 file changed, 99 insertions(+), 33 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 5995e46..57e7b38 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -387,17 +387,80 @@ def _sort_spikes_by_group(rec, sa_rec, output, ks_image_path, ks4_params): recording: spikeinterface recording object. 
""" + logging.info("\n> Sorting spikes per shank.") + # run sorter per shank - #sorting = ss.run_sorter_by_property( - # sorter_name='kilosort4', - # recording=rec, - # grouping_property="group", - # folder=output, - # singularity_image=ks_image_path, - # remove_existing_folder=True, - # verbose=True, - # **ks4_params, - #) + sorting = ss.run_sorter_by_property( + sorter_name="kilosort4", + recording=rec, + grouping_property="group", + folder=output, + singularity_image=ks_image_path, + remove_existing_folder=True, + verbose=True, + **ks4_params, + ) + + if not sa_rec: + recs = [] + groups = rec.split_by("group") + for g, group in groups.items(): + ks_preprocessed = se.read_binary( + file_paths=output/f"{g}/sorter_output/temp_wh.dat", + sampling_frequency=group.sampling_frequency, + dtype=np.int16, + num_channels=group.get_num_channels(), + is_filtered=True, + ) + + # attach probe # to ks4 preprocessed recording, from the raw + with_probe = ks_preprocessed.set_probe(group.get_probe()) + # set properties to make sure sorting & sorting sa have all + # probe # properties to form correct rec_attributes, esp + with_probe._properties = group._properties + + # >>> annotations >>> + annotations = group.get_annotation_keys() + annotations.remove("is_filtered") + for ann in annotations: + with_probe.set_annotation( + annotation_key=ann, + value=group.get_annotation(ann), + overwrite=True, + ) + # <<< annotations <<< + recs.append(with_probe) + + recording = si.aggregate_channels(recs) + else: + recording = sa_rec + + return sorting, recording + + +def _sort_spikes(rec, sa_rec, output, ks_image_path, ks4_params): + """ + Sort spikes with kilosort 4. + + params + === + rec: spikeinterface recording object. + + sa_rec: spikeinterface recording object for creating sorting analyser. + + output: path object, directory of output. + + ks_image_path: path object, directory of local kilosort 4 singularity image. + + ks4_params: dict, parameters for kilosort 4. + + return + === + sorting: spikeinterface sorting object. + + recording: spikeinterface recording object. + """ + logging.info("\n> Sorting spikes.") # run sorter sorting = ss.run_sorter( @@ -424,30 +487,33 @@ def _sort_spikes_by_group(rec, sa_rec, output, ks_image_path, ks4_params): # 1. without whitening, peak amplitude should be ~-70mV # 2. 
with whitening, peak amplitude should be between -1 to 1 - # load ks preprocessed recording for # sorting analyser - ks_preprocessed = se.read_binary( - file_paths=output/"sorter_output/temp_wh.dat", - sampling_frequency=rec.sampling_frequency, - dtype=np.int16, - num_channels=rec.get_num_channels(), - is_filtered=True, - ) - # attach probe # to ks4 preprocessed recording, from the raw - recording = ks_preprocessed.set_probe(rec.get_probe()) - # set properties to make sure sorting & sorting sa have all - # probe # properties to form correct rec_attributes, esp - recording._properties = rec._properties - - # >>> annotations >>> - annotations = rec.get_annotation_keys() - annotations.remove("is_filtered") - for ann in annotations: - recording.set_annotation( - annotation_key=ann, - value=rec.get_annotation(ann), - overwrite=True, + if not sa_rec: + # load ks preprocessed recording for # sorting analyser + ks_preprocessed = se.read_binary( + file_paths=output/"sorter_output/temp_wh.dat", + sampling_frequency=rec.sampling_frequency, + dtype=np.int16, + num_channels=rec.get_num_channels(), + is_filtered=True, ) - # <<< annotations <<< + # attach probe # to ks4 preprocessed recording, from the raw + recording = ks_preprocessed.set_probe(rec.get_probe()) + # set properties to make sure sorting & sorting sa have all + # probe # properties to form correct rec_attributes, esp + recording._properties = rec._properties + + # >>> annotations >>> + annotations = rec.get_annotation_keys() + annotations.remove("is_filtered") + for ann in annotations: + recording.set_annotation( + annotation_key=ann, + value=rec.get_annotation(ann), + overwrite=True, + ) + # <<< annotations <<< + else: + recording = sa_rec return sorting, recording From 257b5c0508033df5a684330454dbaaecd4af2584 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:17:02 +0100 Subject: [PATCH 362/658] add logging --- pixels/pixels_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 57e7b38..4bbd990 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -536,6 +536,8 @@ def _curate_sorting(sorting, recording, output): curated_sa: curated spikeinterface sorting analyser. 
""" + logging.info("\n> Curating sorting.") + # curate sorter output # remove spikes exceeding recording number of samples sorting = sc.remove_excess_spikes(sorting, recording) @@ -643,6 +645,8 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): === None """ + logging.info("\n> Exporting sorting results.") + # export pre curation report sexp.export_report( sorting_analyzer=sa, From 3bf3ba349a083d6698caa3d851fb222219c4ce6d Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:17:20 +0100 Subject: [PATCH 363/658] compute pca too --- pixels/pixels_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4bbd990..cbc080a 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -573,7 +573,7 @@ def _curate_sorting(sorting, recording, output): "template_similarity", "spike_amplitudes", "correlograms", - #"principal_components", # for # phy + "principal_components", # for # phy ] sa.compute(required_extensions, save=True) @@ -585,6 +585,7 @@ def _curate_sorting(sorting, recording, output): max_chan = si.get_template_extremum_channel(sa).values() # get group id for each unit unit_group = group[list(max_chan)] + try: # set unit group as a property for sorting sa.sorting.set_property( key="group", From 4930ca9fc6358ba6d181456b3a1c0077c92b71c2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:18:34 +0100 Subject: [PATCH 364/658] define max_chan even if group is in property; catch index error --- pixels/pixels_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index cbc080a..1e0ab7d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -584,13 +584,18 @@ def _curate_sorting(sorting, recording, output): # get max peak channel for each unit max_chan = si.get_template_extremum_channel(sa).values() # get group id for each unit - unit_group = group[list(max_chan)] try: + unit_group = group[list(max_chan)] + except IndexError: + unit_group = group[sa.channel_ids_to_indices(max_chan)] # set unit group as a property for sorting sa.sorting.set_property( key="group", values=unit_group, ) + else: + # get max peak channel for each unit + max_chan = si.get_template_extremum_channel(sa).values() # calculate quality metrics qms = sqm.compute_quality_metrics(sa) From 7dbe07d292cb1bdb82c469cf6f1de54b13cc9dd8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:21:18 +0100 Subject: [PATCH 365/658] reduce amp threshold --- pixels/pixels_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1e0ab7d..068b507 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -614,11 +614,11 @@ def _curate_sorting(sorting, recording, output): # <<< get depth of units on each shank <<< # remove bad units - #rule = "sliding_rp_violation <= 0.1 & amplitude_median <= -50\ + #rule = "sliding_rp_violation <= 0.1 & amplitude_median <= -40\ # & amplitude_cutoff < 0.05 & sd_ratio < 1.5 & presence_ratio > 0.9\ # & snr > 1.1 & rp_contamination < 0.2 & firing_rate > 0.1" # use the ibl methods, but amplitude_cutoff rather than noise_cutoff - rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -50\ + rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -40\ & presence_ratio > 0.9" good_qms = qms.query(rule) # TODO nov 26 2024 From 3f667ac276748b7d705fd923093688be0a3efc6b Mon Sep 17 00:00:00 
2001 From: amz_office Date: Wed, 14 May 2025 13:21:52 +0100 Subject: [PATCH 366/658] compute pca for phy --- pixels/pixels_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 068b507..6c604ee 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -681,7 +681,6 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): sexp.export_to_phy( sorting_analyzer=curated_sa, output_folder=output/"curated_report/phy", - compute_pc_features=False, copy_binary=False, ) From a49a4d31fe6263f19cf55c6f70870755cc853ebc Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:22:12 +0100 Subject: [PATCH 367/658] make sure to use specific `how` --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 6c604ee..905d023 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1290,7 +1290,7 @@ def _get_vr_positional_neural_data(positions, data_type, data): elif data_type == "spiked": # group values by position and get sum data how = "sum" - grouped_data = group_and_aggregate(trial_data, "position", "sum") + grouped_data = group_and_aggregate(trial_data, "position", how) # reindex into full tunnel length pos_data[trial] = grouped_data.reindex(indices) From 9e15f1adf18c82180aa4b2cff7524962c4dd213a Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:23:27 +0100 Subject: [PATCH 368/658] separate ap & lfp motion correction --- pixels/behaviours/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 746e119..d26975b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -509,7 +509,8 @@ def process_behaviour(self): logging.info("\n> Done!") - def correct_motion(self, mc_method="dredge"): + + def correct_ap_motion(self, mc_method="dredge"): """ Correct motion of recording. @@ -559,6 +560,10 @@ def correct_motion(self, mc_method="dredge"): return None + def correct_lfp_motion(self): + raise NotImplementedError("> Not implemented.") + + def preprocess_raw(self): """ Preprocess full-band raw pixels data. 
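
The curation rule adjusted in patch 365 above is a plain pandas query over the per-unit quality-metrics table returned by compute_quality_metrics. A minimal, self-contained illustration with synthetic numbers (not metrics from any real recording):

    # Synthetic quality-metrics table: one row per unit.
    import pandas as pd

    qms = pd.DataFrame(
        {
            "snr": [2.3, 0.9, 1.5],
            "rp_contamination": [0.05, 0.01, 0.40],
            "amplitude_median": [-85.0, -30.0, -60.0],
            "presence_ratio": [0.98, 0.95, 0.99],
        },
        index=[0, 1, 2],   # unit ids
    )

    rule = ("snr > 1.1 & rp_contamination < 0.2 "
            "& amplitude_median <= -40 & presence_ratio > 0.9")
    good_qms = qms.query(rule)              # only unit 0 passes every condition here
    good_unit_ids = good_qms.index.to_list()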
From 7110256e511229bf815d09d14e65521724de603c Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:24:03 +0100 Subject: [PATCH 369/658] save ap motion correction per stream --- pixels/behaviours/base.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d26975b..18a3e3c 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -532,26 +532,27 @@ def correct_ap_motion(self, mc_method="dredge"): # get pixels streams streams = self.files["pixels"] - for stream_id, stream_files in streams.items(): - output = self.interim / stream_files["motion_corrected"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + output = self.processed / stream_files["ap_motion_corrected"] if output.exists(): - logging.info(f"\n> Motion corrected {stream_id} loaded.") + logging.info(f"\n> {stream_id} already motion corrected.") + stream_files["ap_motion_corrected"] = si.load(output) continue - # preprocess raw recording - self.preprocess_raw() - - # load preprocessed rec - rec = stream_files["preprocessed"] - logging.info( - f"\n>>>>> Correcting motion for recording from {stream_id} " + f"\n>>>>> Correcting motion for ap band from {stream_id} " f"in total of {self.stream_count} stream(s) with {mc_method}" ) - mcd = xut.correct_motion(rec) + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + stream.correct_ap_motion() - mcd.save( + stream_files["ap_motion_corrected"].save( format="zarr", folder=output, compressor=wv_compressor, From ca184d587bb3bed6e484dee961c181e9d53e8957 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:24:47 +0100 Subject: [PATCH 370/658] implement preprocessing at stream level --- pixels/behaviours/base.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 18a3e3c..2856136 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -578,33 +578,19 @@ def preprocess_raw(self): return === - preprocessed: spikeinterface recording. 
+ None """ - # load raw recording as si recording extractor - self.load_raw_ap() - # get pixels streams streams = self.files["pixels"] - for stream_id, stream_files in streams.items(): - # load raw si rec - rec = stream_files["si_rec"] - logging.info( - f"\n>>>>> Preprocessing data for recording from {stream_id} " - f"in total of {self.stream_count} stream(s)" - ) - - # load brain surface depths - depth_info = load_yaml( - path=self.histology / stream_files["depth_info"], - ) - surface_depths = depth_info["raw_signal_depths"][stream_id] - - # preprocess - stream_files["preprocessed"] = xut.preprocess_raw( - rec, - surface_depths, + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, ) + stream.preprocess_raw() return None From 8d294aaa55aed8b86da6b95057bb84f158202d1b Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:25:13 +0100 Subject: [PATCH 371/658] directly assign extracted ap obj --- pixels/behaviours/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2856136..8d563a9 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -625,8 +625,7 @@ def detect_n_localise_peaks(self, loc_method="monopolar_triangulation"): continue # get ap band - ap_file = self.find_file(stream_files["ap_extracted"]) - rec = si.load_extractor(ap_file) + rec = stream_files["ap_extracted"] # detect and localise peaks df = xut.detect_n_localise_peaks(rec) From afc435fcea4b21881bf6b3f5fb7cbd8d78b900e3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:25:43 +0100 Subject: [PATCH 372/658] implement band extraction at stream level --- pixels/behaviours/base.py | 51 +++++++++++---------------------------- 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 8d563a9..593c69f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -641,45 +641,22 @@ def extract_bands(self, freqs=None): extract data of ap and lfp frequency bands from the raw neural recording data. 
""" - if freqs == None: - bands = freq_bands - elif isinstance(freqs, str) and freqs in freq_bands.keys(): - bands = {freqs: freq_bands[freqs]} - elif isinstance(freqs, dict): - bands = freqs + # preprocess raw + self.preprocess_raw() streams = self.files["pixels"] - for stream_id, stream_files in streams.items(): - for name, freqs in bands.items(): - output = self.processed / stream_files[f"{name}_extracted"] - if output.exists(): - logging.info(f"\n> {name} bands from {stream_id} loaded.") - continue - - # preprocess raw data - self.preprocess_raw() - - logging.info( - f"\n>>>>> Extracting {name} bands from {self.name} " - f"{stream_id} in total of {self.stream_count} stream(s)" - ) - - # load preprocessed - rec = stream_files["preprocessed"] - - # do bandpass filtering - extracted = xut.extract_band( - rec, - freq_min=freqs[0], - freq_max=freqs[1], - ) - - # write to disk - extracted.save( - format="zarr", - folder=output, - compressor=wv_compressor, - ) + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + logging.info( + f"\n>>>>> Extracting bands from {self.name} " + f"{stream_id} in total of {self.stream_count} stream(s)" + ) + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + stream.extract_bands(freqs) """ if self._lag[rec_num] is None: From 28e6fdf57104215bcec5f2dd292b13b290b7c113 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:26:19 +0100 Subject: [PATCH 373/658] update doc --- pixels/behaviours/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 593c69f..82341dd 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -748,9 +748,9 @@ def run_catgt(self, CatGT_app=None, args=None) -> None: def load_raw_ap(self): """ - Write a function to load and concatenate raw recording files for each - stream (i.e., probe), so that data from all runs of the same probe can - be preprocessed and sorted together. + Load and concatenate raw recording files for each stream (i.e., probe), + so that data from all runs of the same probe can be preprocessed and + sorted together. """ # if multiple runs for the same probe, concatenate them streams = self.files["pixels"] From 6d464fbf36911d8d5935fcaa25f7fcc4b18e2049 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:26:35 +0100 Subject: [PATCH 374/658] whiten recordings --- pixels/behaviours/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 82341dd..6f491c0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -774,6 +774,21 @@ def load_raw_ap(self): return None + def whiten_ap(self): + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + logging.info(f"\n> Whitening {stream_id}.") + stream.whiten_ap() + + return None + + def sort_spikes(self, mc_method="dredge"): """ Run kilosort spike sorting on raw spike data. 
From d6905f8a6d0e6958daa44c2268cfa32f48b72540 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:26:59 +0100 Subject: [PATCH 375/658] use load instead to be future-proof --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 6f491c0..3c0dced 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1740,7 +1740,7 @@ def select_units( stream_files = self.files["pixels"]["imec0.ap"] sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser - temp_sa = si.load_sorting_analyzer(sa_dir) + temp_sa = si.load(sa_dir) # remove noisy units try: From ec2b16027e32d0d6c5de755560b532ff153f2918 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:27:41 +0100 Subject: [PATCH 376/658] define ks4 image version in constants --- pixels/behaviours/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 3c0dced..f89dac8 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -800,8 +800,7 @@ def sort_spikes(self, mc_method="dredge"): (as of jan 2025, dredge performs better than ks motion correction.) "ks": do motion correction with kilosort. """ - ks_image_path = self.interim.parent/"ks4-0-30_with_wavpack.sif" - #ks_image_path = self.interim.parent/"ks4-0-18_with_wavpack.sif" + ks_image_path = self.interim.parent / ks4_image_name if not ks_image_path.exists(): raise PixelsError("Have you craeted Singularity image for sorting?") From 999c3d3b713555c54117d8688d5fd4246c5400d9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:30:31 +0100 Subject: [PATCH 377/658] only correct motion if do not use ks motion correction --- pixels/behaviours/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index f89dac8..d7c4aa0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -805,13 +805,13 @@ def sort_spikes(self, mc_method="dredge"): if not ks_image_path.exists(): raise PixelsError("Have you craeted Singularity image for sorting?") - # preprocess and motion correct raw - self.correct_motion(mc_method) - if mc_method == "ks": ks_mc = True else: ks_mc = False + # preprocess and motion correct raw, also whiten + self.correct_ap_motion(mc_method) + self.whiten_ap() # set ks4 parameters ks4_params = { From c70eea9e95974cf2ac0b8e5b42957783a5d3ebd8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:31:30 +0100 Subject: [PATCH 378/658] params change with whether using ks motion correction --- pixels/behaviours/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d7c4aa0..27b114a 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -815,9 +815,10 @@ def sort_spikes(self, mc_method="dredge"): # set ks4 parameters ks4_params = { - "do_correction": ks_mc, "do_CAR": False, # do not common average reference - "save_preprocessed_copy": True, # save ks4 preprocessed data + "skip_kilosort_preprocessing": not ks_mc, + "do_correction": ks_mc, + "save_preprocessed_copy": ks_mc, # save ks4 preprocessed data } streams = self.files["pixels"] From 7b4fb637c45cd3a411183211a89bb627fd14c772 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:32:15 +0100 Subject: [PATCH 379/658] only print if sa exists --- 
pixels/behaviours/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 27b114a..2fd3d5b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -825,10 +825,8 @@ def sort_spikes(self, mc_method="dredge"): for stream_num, (stream_id, stream_files) in enumerate(streams.items()): # check if already sorted and exported sa_dir = self.processed / stream_files["sorting_analyser"] - if not sa_dir.exists(): - logging.info(f"\n> {self.name} {stream_id} not sorted/exported.") - else: - logging.info("\n> Already sorted and exported, next session.") + if sa_dir.exists(): + logging.info("\n> Already sorted and exported, next stream.") continue # get catgt directory From b23bcbba2d8b6dce55c0ef61f3aeabcbe73d27bd Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 13:32:33 +0100 Subject: [PATCH 380/658] implement spike sorting at stream level --- pixels/behaviours/base.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2fd3d5b..c71d8f7 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -840,25 +840,18 @@ def sort_spikes(self, mc_method="dredge"): else: output = self.processed / f"sorted_stream_cat_{stream_num}" - # load rec - if ks_mc: - # preprocess raw recording - self.preprocess_raw() - rec = stream_files["preprocessed"] - else: - rec_dir = self.find_file(stream_files["motion_corrected"]) - rec = si.load_extractor(rec_dir) - - # move current working directory to interim - os.chdir(self.interim) - - # sort spikes and save sorting analyser to disk - xut.sort_spikes( - rec=rec, - output=output, - curated_sa_dir=sa_dir, + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + stream.sort_spikes( + ks_mc=ks_mc, ks_image_path=ks_image_path, ks4_params=ks4_params, + output=output, + sa_dir=sa_dir, ) return None From 0d53cf4ec3555131c5f309df96198c78235334f2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 14:34:10 +0100 Subject: [PATCH 381/658] output whitened recording dtype to int16 and scale it to have sd 200 --- pixels/pixels_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 905d023..7853544 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -296,7 +296,8 @@ def extract_band(rec, freq_min, freq_max, ftype="butter"): def whiten(rec): whitened = spre.whiten( recording=rec, - dtype=np.float32, + dtype=np.int16, + int_scale=200, # scale traces value to sd of 200, in line with ks4 mode="local", radius_um=240.0, # 16 nearby chans in line with ks4 ) From 71e75f96b30561bc4fc35d88d11361543fc4c8aa Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 14 May 2025 14:40:45 +0100 Subject: [PATCH 382/658] make sure preprocessed dtype being int16 in common_reference --- pixels/pixels_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 7853544..2fe3353 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -79,12 +79,6 @@ def preprocess_raw(rec, surface_depths): # preprocess preprocessed = _preprocess_raw(rec, surface_depth) - # NOTE jan 16 2025: - # BUG: cannot set dtype back to int16, units from ks4 will have - # incorrect amp & loc - if not preprocessed.dtype == np.dtype("int16"): - preprocessed = 
spre.astype(preprocessed, dtype=np.int16) - return preprocessed @@ -119,9 +113,9 @@ def _preprocess_raw(rec, surface_depth): print(f"\t\t> Removed {outside_chan_ids.size} outside channels.") print("\t> step 3: do common median referencing.") - # NOTE: dtype will be converted to float32 during motion correction cmr = spre.common_reference( rec_clean, + dtype=np.int16, # make sure output is int16 ) return cmr From 6149e2ac63df08171f829a5f3f8712503c607816 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 11:02:13 +0100 Subject: [PATCH 383/658] no whitening, and always do ks preprocessing --- pixels/behaviours/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index c71d8f7..785faf1 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -809,16 +809,17 @@ def sort_spikes(self, mc_method="dredge"): ks_mc = True else: ks_mc = False - # preprocess and motion correct raw, also whiten + # preprocess and motion correct raw self.correct_ap_motion(mc_method) - self.whiten_ap() + # XXX: no whitening + #self.whiten_ap() # set ks4 parameters ks4_params = { "do_CAR": False, # do not common average reference - "skip_kilosort_preprocessing": not ks_mc, + "skip_kilosort_preprocessing": False, "do_correction": ks_mc, - "save_preprocessed_copy": ks_mc, # save ks4 preprocessed data + "save_preprocessed_copy": True, # save ks4 preprocessed data } streams = self.files["pixels"] From b47a819141a0e52f70d935d76fce274e06016343 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 11:02:31 +0100 Subject: [PATCH 384/658] add notes --- pixels/behaviours/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 785faf1..0342e78 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1733,6 +1733,9 @@ def select_units( sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser temp_sa = si.load(sa_dir) + # NOTE: this load gives warning when using temp_wh.dat to build + # sorting analyser, and to load sa it will be loading binary + # recording obj, and that checks version # remove noisy units try: From 70820a1a01b9669b686ef3c17d60d1512630090d Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 11:02:58 +0100 Subject: [PATCH 385/658] add per shank arg --- pixels/pixels_utils.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 2fe3353..43cd456 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -299,7 +299,8 @@ def whiten(rec): return whitened -def sort_spikes(rec, sa_rec, output, curated_sa_dir, ks_image_path, ks4_params): +def sort_spikes(rec, sa_rec, output, curated_sa_dir, ks_image_path, ks4_params, + per_shank=False): """ Sort spikes with kilosort 4, curate sorting, save sorting analyser to disk, and export results to disk. @@ -333,13 +334,24 @@ def sort_spikes(rec, sa_rec, output, curated_sa_dir, ks_image_path, ks4_params): # si.aggregate_channels... 
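If the per_shank path is ever taken, spikeinterface already handles the split-and-aggregate bookkeeping the note above worries about: run the sorter once per channel group and get back one aggregated sorting. A sketch of what _sort_spikes_by_group could look like (container/image handling omitted; this is not the project's implementation):

    import numpy as np
    import spikeinterface.sorters as ss

    def sort_per_shank(rec, output, ks4_params):
        assert np.unique(rec.get_channel_groups()).size > 1, "single-shank recording"
        # one kilosort4 run per channel group, results aggregated into one sorting
        return ss.run_sorter_by_property(
            "kilosort4",
            rec,
            "group",
            str(output),
            **ks4_params,
        )
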
# sort spikes - sorting, recording = _sort_spikes( - rec, - sa_rec, - output, - ks_image_path, - ks4_params, - ) + if np.unique(rec.get_channel_groups()).size > 1 and per_shank: + # per shank + sorting, recording = _sort_spikes_by_group( + rec, + sa_rec, + output, + ks_image_path, + ks4_params, + ) + else: + # all together + sorting, recording = _sort_spikes( + rec, + sa_rec, + output, + ks_image_path, + ks4_params, + ) # curate sorting sa, curated_sa = _curate_sorting( From 8150d9e33695d90832662f465cc07429c6f0af96 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:27:19 +0100 Subject: [PATCH 386/658] separate export to phy as an option --- pixels/pixels_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 43cd456..d1dd5eb 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -684,10 +684,14 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): folder=curated_sa_dir, ) + return None + + +def export_sa_to_phy(path, sa): # export to phy for additional manual curation if needed sexp.export_to_phy( - sorting_analyzer=curated_sa, - output_folder=output/"curated_report/phy", + sorting_analyzer=sa, + output_folder=path/"phy", copy_binary=False, ) From c5893e4dbc8f46ff72506dc42e8fa35ba037f0bf Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:29:13 +0100 Subject: [PATCH 387/658] export whitened data as float --- pixels/pixels_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d1dd5eb..36c285c 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -290,8 +290,9 @@ def extract_band(rec, freq_min, freq_max, ftype="butter"): def whiten(rec): whitened = spre.whiten( recording=rec, - dtype=np.int16, - int_scale=200, # scale traces value to sd of 200, in line with ks4 + dtype=np.float32, + #dtype=np.int16, + #int_scale=200, # scale traces value to sd of 200, in line with ks4 mode="local", radius_um=240.0, # 16 nearby chans in line with ks4 ) From ed578dc0f75fbc841d6427cf2891638dea9664c6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:30:15 +0100 Subject: [PATCH 388/658] update todo --- pixels/pixels_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 36c285c..cef220e 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -492,8 +492,7 @@ def _sort_spikes(rec, sa_rec, output, ks_image_path, ks4_params): # 2. still use the temp_wh.dat from ks4, but check how ks4 handles amplitude # and the unit of amplitude, correct it # WHAT TO ACHIEVE: - # 1. without whitening, peak amplitude should be ~-70mV - # 2. with whitening, peak amplitude should be between -1 to 1 + # 1. 
without whitening, peak amplitude median should be ~-70uV if not sa_rec: # load ks preprocessed recording for # sorting analyser From f84270f5d802ca5544162ec26d11138d54891020 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:55:11 +0100 Subject: [PATCH 389/658] update notes --- pixels/pixels_utils.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index cef220e..3223c29 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -481,19 +481,10 @@ def _sort_spikes(rec, sa_rec, output, ks_image_path, ks4_params): **ks4_params, ) - # TODO apr 10 2025: - # since this file is whitened, the amplitude of the signal is NOT the same - # as the original, and this might cause issue in calculating signal - # amplitude in spikeinterface. cuz in ks4 output, units amplitude is between - # 0-315, but in si it's between -4000 to 4000. - # POTENTIAL SOLUTIONS: - # 1. do what chris does, make another preprocessed recording just to build - # the sorting analyser, or - # 2. still use the temp_wh.dat from ks4, but check how ks4 handles amplitude - # and the unit of amplitude, correct it - # WHAT TO ACHIEVE: - # 1. without whitening, peak amplitude median should be ~-70uV - + # NOTE: may 20 2025 + # build sa with non-whitened preprocessed rec gives amp between 0-250uV, + # which makes sense, and quality metric amp_median is comparable across + # recordings if not sa_rec: # load ks preprocessed recording for # sorting analyser ks_preprocessed = se.read_binary( From c765b6c15ebef8440ccb83249a8b4dd58e26ba61 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:55:25 +0100 Subject: [PATCH 390/658] put to_phy as an option --- pixels/pixels_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 3223c29..7cd78d9 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -633,7 +633,8 @@ def _curate_sorting(sorting, recording, output): return sa, curated_sa -def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): +def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir, + to_phy=False): """ Export sorting analyser to disk. 
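Since amp_median and the other quality metrics live on the sorting analyser, the "remove noisy units" step can be expressed as a simple query on that table. The thresholds and exact column names below are illustrative rather than the project's curation rules, and assume the quality_metrics extension has already been computed:

    def drop_noisy_units(sa, max_isi_violations=0.5, min_amp_median_uv=20.0):
        qm = sa.get_extension("quality_metrics").get_data()   # pandas DataFrame
        keep = qm[
            (qm["isi_violations_ratio"] < max_isi_violations)
            & (qm["amplitude_median"].abs() > min_amp_median_uv)
        ].index.to_list()
        return sa.select_units(keep)
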
@@ -675,6 +676,9 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir): folder=curated_sa_dir, ) + if to_phy: + export_sa_to_phy(output, sa) + return None From 0198720cdb939b8c0fe169a75d03dc7a5c032ce5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:55:58 +0100 Subject: [PATCH 391/658] ignore positions after 600 --- pixels/pixels_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 7cd78d9..8551137 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1252,10 +1252,10 @@ def _get_vr_positional_neural_data(positions, data_type, data): # get the starting index for each trial (column) starts = positions.iloc[0, :].astype(int) # create position indices - indices = np.arange(0, TUNNEL_RESET+2) + indices = np.arange(0, TUNNEL_RESET+1) # create occupancy array for trials occupancy = np.full( - (TUNNEL_RESET+2, positions.shape[1]), + (TUNNEL_RESET+1, positions.shape[1]), np.nan, ) From 2bcc29911b9a9e9af9ab7ac7199a4ce38b0756c3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:56:27 +0100 Subject: [PATCH 392/658] floor positions to avoid overflowing --- pixels/pixels_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 8551137..74816ad 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1264,15 +1264,11 @@ def _get_vr_positional_neural_data(positions, data_type, data): # get trial position trial_pos = positions[trial].dropna() - # floor pre reward zone and end ceil post zone end - trial_pos = trial_pos.apply( - lambda x: np.floor(x) if x <= ZONE_END else np.ceil(x) - ) - # set to int - trial_pos = trial_pos.astype(int) + # floor position and set to int + trial_pos = trial_pos.apply(lambda x: np.floor(x)).astype(int) # exclude positions after tunnel reset - trial_pos = trial_pos[trial_pos <= TUNNEL_RESET+1] + trial_pos = trial_pos[trial_pos <= TUNNEL_RESET] # get firing rates for current trial of all units trial_data = data.xs( From 5a4372c6c3058c79a1e19fa5e78c862f0fb969bf Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 12:57:10 +0100 Subject: [PATCH 393/658] organise sorting of all levels --- pixels/pixels_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 74816ad..9390912 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1321,7 +1321,12 @@ def _get_vr_positional_neural_data(positions, data_type, data): names=["start", "unit", "trial"], ) pos_data.columns = new_cols - # sort by unit - pos_data = pos_data.sort_index(level="unit", axis=1) + + # sort by unit, starting position, and then trial + pos_data = pos_data.sort_index( + axis=1, + level=["unit", "start", "trial"], + ascending=[True, False, True], + ) return pos_data, occupancy From 649adae2be91c0f81700a175bc0d56a9ee392dc3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 20 May 2025 13:00:56 +0100 Subject: [PATCH 394/658] remove sorting separately notes --- pixels/pixels_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9390912..0fdbfbd 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -326,13 +326,6 @@ def sort_spikes(rec, sa_rec, output, curated_sa_dir, ks_image_path, ks4_params, recording: spikeinterface recording object. """ - # NOTE: jan 30 2025 do we sort shanks separately??? 
- # if shanks are sorted separately, they will have separate sorter output, we - # will have to build an analyser for each group... - # maybe easier to just run all shanks together? - # the only way to concatenate four temp.dat and only create one sorting - # analyser is to read temp_wh.dat, set channels separately from raw, and - # si.aggregate_channels... # sort spikes if np.unique(rec.get_channel_groups()).size > 1 and per_shank: From d5675537dde66384db4749d9f832e57bcc791d3d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 13:23:56 +0100 Subject: [PATCH 395/658] get max trial count for trials in tunnel --- pixels/behaviours/virtual_reality.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 34abf50..f3caed4 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -221,7 +221,7 @@ def _extract_action_labels(self, vr, vr_data): # label trial starts action_labels[trial_starts, 1] += Events.trial_start - if not trial_starts.size == vr_data.trial_count.max(): + if not trial_starts.size == vr_data[in_tunnel].trial_count.max(): raise PixelsError(f"Number of trials does not equal to " "{vr_data.trial_count.max()}.") # NOTE: if trial starts at 0, the first position_in_tunnel value will From 6ef14c3509a3ae380b360a4d38dcc150376f767b Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:04:50 +0100 Subject: [PATCH 396/658] use logging --- pixels/experiment.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 138166a..10f7be8 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -12,6 +12,7 @@ from pixels import ioutils from pixels.error import PixelsError +from pixels.configs import * class Experiment: @@ -224,7 +225,15 @@ def select_units(self, *args, **kwargs): units = {} for i, session in enumerate(self.sessions): - units[session.name] = session.select_units(*args, **kwargs) + name = session.name + selected = session.select_units(*args, **kwargs) + if len(selected) == 0: + logging.warning( + f"\n> {name} does not have units in {selected}, " + "skip." + ) + else: + units[name] = selected return units From 58f13551b8da62c6af553d22a58cc8b4208ff2f2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:05:32 +0100 Subject: [PATCH 397/658] only loop through sessions if there are units found in that session --- pixels/experiment.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 10f7be8..9ae7669 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -545,21 +545,22 @@ def get_binned_trials(self, *args, units=None, **kwargs): Get binned firing rate and spike count for aligned vr trials. Check behaviours.base.Behaviour.get_binned_trials for usage information. 
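Taken together with the select_units change above, a hypothetical experiment-level call sequence looks like the following; exp is assumed to be an existing pixels Experiment, and the positional arguments only mirror the label/event convention used elsewhere in the code rather than a checked signature:

    from pixels.behaviours.virtual_reality import ActionLabels, Events

    def binned_rewarded_light(exp):
        units = exp.select_units()            # {session_name: selected units}
        return exp.get_binned_trials(
            ActionLabels.rewarded_light,      # trial label
            Events.trial_start,               # alignment event
            units=units,                      # sessions without units are skipped
        )
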
""" + session_names = [session.name for session in self.sessions] trials = {} - for i, session in enumerate(self.sessions): - name = session.name - result = None - if units: - if units[name]: - result = session.get_binned_trials( - *args, - units=units[name], - **kwargs, - ) - else: + if not units is None: + for name in units.keys(): + session = self.sessions[session_names.index(name)] + trials[name] = session.get_binned_trials( + *args, + units=units[name], + **kwargs, + ) + else: + for i, session in enumerate(self.sessions): + name = session.name result = session.get_binned_trials(*args, **kwargs) - if result is not None: - trials[name] = result + if not result is None: + trials[name] = result level_names = ["session", "stream", "unit", "trial"] bin_fr = ioutils.get_aligned_data_across_sessions( From 75bd332cf4038efb04e83cfdfa5603ea74ab4a68 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:06:02 +0100 Subject: [PATCH 398/658] synchronise pixels with vr for multiple sessions --- pixels/experiment.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pixels/experiment.py b/pixels/experiment.py index 9ae7669..11c3173 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -585,3 +585,14 @@ def get_binned_trials(self, *args, units=None, **kwargs): } return df + + + def sync_vr(self, vr): + """ + Synchronise virtual reality data with pixels streams. + """ + trials = {} + for i, session in enumerate(self.sessions): + session.sync_vr(vr) + + return None From dcb982333cbe843fcdcf194a8429e9057a496e2a Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:08:04 +0100 Subject: [PATCH 399/658] store sorting analyser directly to disk --- pixels/pixels_utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 0fdbfbd..939ffd6 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -550,6 +550,10 @@ def _curate_sorting(sorting, recording, output): sa = si.create_sorting_analyzer( sorting=sorting, recording=recording, + sparse=True, + format="zarr", + folder=output/"sa.zarr", + overwrite=True, ) # calculate all extensions BEFORE further steps @@ -651,12 +655,6 @@ def _export_sorting_analyser(sa, curated_sa, output, curated_sa_dir, output_folder=output/"report", ) - # save pre-curated analyser to disk - sa.save_as( - format="zarr", - folder=output/"sa.zarr", - ) - # export curated report sexp.export_report( sorting_analyzer=curated_sa, From 5f4f18368e9dde5d1cb417d466b1e4eab35f3bb8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:13:19 +0100 Subject: [PATCH 400/658] add boolean flag to sort per shank --- pixels/pixels_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 939ffd6..0a09593 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -320,6 +320,10 @@ def sort_spikes(rec, sa_rec, output, curated_sa_dir, ks_image_path, ks4_params, ks4_params: dict, parameters for kilosort 4. + per_shank: bool, whether to sort recording per shank. + Default: False (as of may 2025, sort shanks separately by ks4 gives less + units) + return === sorting: spikeinterface sorting object. 
From c081b6db6a35b4732ece3a25c4c28ddecb12b05a Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:19:07 +0100 Subject: [PATCH 401/658] remove redundant hashtag --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 0a09593..68c62b1 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -572,7 +572,7 @@ def _curate_sorting(sorting, recording, output): "template_similarity", "spike_amplitudes", "correlograms", - "principal_components", # for # phy + "principal_components", # for phy ] sa.compute(required_extensions, save=True) From 1d750f00301a6f996cf0a95727215492ceee352e Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:19:30 +0100 Subject: [PATCH 402/658] change default motion correction method for sorting recording to ks --- pixels/behaviours/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 0342e78..84d01dd 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -789,16 +789,18 @@ def whiten_ap(self): return None - def sort_spikes(self, mc_method="dredge"): + def sort_spikes(self, mc_method="ks"): """ Run kilosort spike sorting on raw spike data. params === mc_method: str, motion correction method. - Default: "dredge". - (as of jan 2025, dredge performs better than ks motion correction.) + Default: "ks". + (as of may 2025, feeding ap dredged recording to kilosort 4 + gives much less unit, so just let ks4 does its thing.) "ks": do motion correction with kilosort. + "dredge": do motion correction with dredge on ap band. """ ks_image_path = self.interim.parent / ks4_image_name From a4acc66a23adfff5f130ae5c6802e0381ad5f956 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:20:16 +0100 Subject: [PATCH 403/658] do separate motion correction anyways to build sorting analyser --- pixels/behaviours/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 84d01dd..d5f5d6b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -807,12 +807,13 @@ def sort_spikes(self, mc_method="ks"): if not ks_image_path.exists(): raise PixelsError("Have you craeted Singularity image for sorting?") + # ap band motion correct ONLY for building sorting analyser + self.correct_ap_motion() + if mc_method == "ks": ks_mc = True else: ks_mc = False - # preprocess and motion correct raw - self.correct_ap_motion(mc_method) # XXX: no whitening #self.whiten_ap() From b37d6b67aec8661648554d2a7326a7c158457eea Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 22 May 2025 15:20:56 +0100 Subject: [PATCH 404/658] use ap motion corrected rec to build sorting analyser; do not use whitened rec for now --- pixels/stream.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 22753d4..f4671c6 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -666,14 +666,12 @@ def sort_spikes(self, ks_mc, ks4_params, ks_image_path, output, sa_dir): if ks_mc: self.preprocess_raw() rec = self.files["preprocessed"] - sa_rec = None + sa_rec = self.files["ap_motion_corrected"] else: - # whiten ap band and feed to ks - rec = self.files["ap_whitened"] - # use non-whitened recording for sorting analyser - #sa_rec = self.files["ap_motion_corrected"] - sa_rec = rec - # TODO may 13 2025: test building sa with whitened + # XXX: 
as of may 2025, whiten ap band and feed to ks reduce units! + #rec = self.files["ap_whitened"] + # use non-whitened recording for ks4 and sorting analyser + sa_rec = rec = self.files["ap_motion_corrected"] # sort spikes and save sorting analyser to disk xut.sort_spikes( From c9f71e22727b1147192f52e753450b42710768f0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 30 May 2025 16:41:41 +0100 Subject: [PATCH 405/658] put end event as the last argument --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d5f5d6b..63d88ac 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2655,8 +2655,8 @@ def get_positional_data( units=units, # NOTE: put units first! label=label, event=event, - end_event=end_event, sigma=sigma, + end_event=end_event, ) return output From 316e31eed4c64eef81c58761f45510e754a23930 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 30 May 2025 16:42:18 +0100 Subject: [PATCH 406/658] implement func to get power spectral density --- pixels/behaviours/base.py | 35 +++++++++++++++++++++++++++++ pixels/pixels_utils.py | 17 +++++++++++++++ pixels/stream.py | 46 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 63d88ac..db6bb51 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2821,3 +2821,38 @@ def sync_vr(self, vr): stream.sync_vr(vr) return None + + + def get_spatial_psd( + self, label, event, end_event=None, sigma=None, units=None, + ): + """ + Get spatial power spectral density of selected units. + """ + # NOTE: order of args matters for loading the cache! + # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded + + output = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting spatial PSD of {units} units in " + f"<{label}> trials." + ) + output[stream_id] = stream.get_spatial_psd( + units=units, # NOTE: put units first! + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: put units last! + ) + + return output diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 68c62b1..d8e3d0e 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1325,3 +1325,20 @@ def _get_vr_positional_neural_data(positions, data_type, data): ) return pos_data, occupancy + + +def get_spatial_psd(pos_fr): + def _compute_psd(col): + x = col.dropna().values.squeeze() + f, psd = math_utils.estimate_power_spectrum(x) + # remove 0 to avoid infinity + f = f[1:] + psd = psd[1:] + ser = pd.Series(psd, index=f, name=col.name) + ser.index.name = "frequency" + + return ser + + psd = pos_fr.apply(_compute_psd, axis=0) + + return psd diff --git a/pixels/stream.py b/pixels/stream.py index f4671c6..3dff052 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -716,3 +716,49 @@ def save_spike_chance(self, spiked, sigma): ) return None + + + def get_spatial_psd( + self, label, event, end_event=None, sigma=None, units=None, + ): + """ + Get spatial power spectral density of selected units. + """ + # NOTE: order of args matters for loading the cache! 
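math_utils.estimate_power_spectrum is the project's own helper; a minimal stand-in with the same output shape could use scipy's periodogram. With positional firing rates sampled at one value per cm, fs=1.0 puts the frequency axis in cycles per cm:

    import numpy as np
    import pandas as pd
    from scipy import signal

    def estimate_power_spectrum(x, fs=1.0):
        # one-sided periodogram; fs=1.0 -> frequency in cycles per cm
        return signal.periodogram(x, fs=fs, detrend="constant")

    def spatial_psd(pos_fr_col):
        x = pos_fr_col.dropna().to_numpy()
        f, psd = estimate_power_spectrum(x)
        # drop the DC bin, as the code above does, to avoid the zero-frequency term
        return pd.Series(psd[1:], index=pd.Index(f[1:], name="frequency"))
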
+ # always put units first, cuz it is like that in + # experiemnt.align_trials, otherwise the same cache cannot be loaded + + from vision_in_darkness.constants import landmarks + # get aligned firing rates and positions + trials = self.get_positional_data( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + # get positional fr + pos_fr = trials["pos_fr"] + + starts = pos_fr.columns.get_level_values("start").unique() + psds = {} + for s, start in enumerate(starts): + data = pos_fr.xs(start, level="start", axis=1) + # remove black wall and post last landmark + cropped = data.loc[landmarks[0]:landmarks[-1], :] + # TODO may 30 2025: + # only remove 60cm black wall in light, remove first 50cm of tunnel + # anyways in dark! + + # get power spectral density + #psds[start] = xut.get_spatial_psd(data) + psds[start] = xut.get_spatial_psd(cropped) + + psd_df = pd.concat( + psds, + names=["start","frequency"], + ) + # NOTE: all trials will appear in all starts, but their values will be + # all nan in other starts, so remember to dropna(axis=1)! + + return psd_df From 72d7b12815a1304464ba2f199fea44351b41589e Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 30 May 2025 16:42:41 +0100 Subject: [PATCH 407/658] add unit --- pixels/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/constants.py b/pixels/constants.py index 9ee37c3..4c55af1 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -3,7 +3,7 @@ """ import numpy as np -SAMPLE_RATE = 2000 +SAMPLE_RATE = 2000 # Hz freq_bands = { "ap":[300, 9000], From a73de72a0f8052591615b09aaacced15d2562ef3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 30 May 2025 16:42:48 +0100 Subject: [PATCH 408/658] add completed trials category --- pixels/behaviours/virtual_reality.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index f3caed4..00f8c5e 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -68,6 +68,10 @@ class ActionLabels: rewarded_dark = triggered_dark | auto_dark | reinf_dark given_light = default_light | auto_light | reinf_light given_dark = auto_dark | reinf_dark + # no punishment, i.e., no missing data + completed_light = miss_light | triggered_light | default_light\ + | auto_light | reinf_light + completed_dark = miss_dark | triggered_dark | auto_dark | reinf_dark class Events: From 031bea833899e7dbc04abfd0b7fbc8d10710a3f5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 2 Jun 2025 11:17:54 +0100 Subject: [PATCH 409/658] import the whole math_utils --- pixels/pixels_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d8e3d0e..bb2e76d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -21,7 +21,7 @@ from pixels.error import PixelsError from pixels.configs import * -from common_utils.math_utils import random_sampling, group_and_aggregate +from common_utils import math_utils from common_utils.file_utils import init_memmap, read_hdf5 def load_raw(paths, stream_id): @@ -1287,7 +1287,7 @@ def _get_vr_positional_neural_data(positions, data_type, data): elif data_type == "spiked": # group values by position and get sum data how = "sum" - grouped_data = group_and_aggregate(trial_data, "position", how) + grouped_data = math_utils.group_and_aggregate(trial_data, "position", how) # reindex into full tunnel length 
pos_data[trial] = grouped_data.reindex(indices) From 28ff1aa79ab88f09b4993c8126e3a46484f3c6ca Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 2 Jun 2025 19:29:50 +0100 Subject: [PATCH 410/658] add pre_dark_len event --- pixels/behaviours/virtual_reality.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 00f8c5e..23d6036 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -93,19 +93,20 @@ class Events: trial_end = 1 << 9 # 512 # positional events - black = 1 << 10 # 0 - 60 cm - wall = 1 << 11 # in between landmarks - landmark1 = 1 << 12 # 110 - 130 cm - landmark2 = 1 << 13 # 190 - 210 cm - landmark3 = 1 << 14 # 270 - 290 cm - landmark4 = 1 << 15 # 350 - 370 cm - landmark5 = 1 << 16 # 430 - 450 cm - reward_zone = 1 << 17 # 460 - 495 cm + pre_dark_end = 1 << 10 # 50 cm + black = 1 << 12 # 0 - 60 cm + wall = 1 << 12 # in between landmarks + landmark1 = 1 << 13 # 110 - 130 cm + landmark2 = 1 << 14 # 190 - 210 cm + landmark3 = 1 << 15 # 270 - 290 cm + landmark4 = 1 << 16 # 350 - 370 cm + landmark5 = 1 << 17 # 430 - 450 cm + reward_zone = 1 << 18 # 460 - 495 cm # sensors - valve_open = 1 << 18 # 262144 - valve_closed = 1 << 19 # 524288 - licked = 1 << 20 # 1048576 + valve_open = 1 << 19 # 262144 + valve_closed = 1 << 20 # 524288 + licked = 1 << 21 # 1048576 #run_start = 1 << 12 #run_stop = 1 << 13 From d7a250727bbd7f782e612b53de79971c104ef04c Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 2 Jun 2025 19:30:13 +0100 Subject: [PATCH 411/658] add note & clarify format --- pixels/behaviours/virtual_reality.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 23d6036..d2cbda9 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -238,6 +238,12 @@ def _extract_action_labels(self, vr, vr_data): action_labels[light_off, 1] += Events.light_off # <<<< light <<<< + # NOTE: dark trials should in theory have EQUAL index pre_dark_end_idx + # and dark_on, BUT! after interpolation, some dark trials have their + # dark onset earlier than expected, those are corrected to the first + # expected position, they will have the same index as pre_dark_end_idx. + # others will not since their world_index change later than expected. 
+ # NOTE: if dark trial is aborted, light tunnel only turns off once; but # if it is a reward is dispensed, light tunnel turns off twice @@ -268,6 +274,7 @@ def _extract_action_labels(self, vr, vr_data): # TODO jun 27 2024 positional events and valve events needs mapping + # NOTE: AL remove pre_dark_len + 10cm of his data logging.info("\n>> Mapping vr action times...") # >>>> map reward types >>>> @@ -281,6 +288,7 @@ def _extract_action_labels(self, vr, vr_data): trial_idx = np.where(of_trial)[0] # get start index of current trial start_idx = trial_idx[np.isin(trial_idx, trial_starts)] + # find where is non-zero reward type in current trial reward_typed = vr_data[of_trial & reward_not_none] # get trial type of current trial @@ -288,6 +296,8 @@ def _extract_action_labels(self, vr, vr_data): # get name of trial type in string trial_type_str = trial_type_lookup.get(trial_type).lower() + # >>>> map reward types >>>> + # >>>> punished >>>> if (reward_typed.size == 0)\ & (vr_data[of_trial & in_white].size != 0): @@ -338,7 +348,8 @@ def _extract_action_labels(self, vr, vr_data): ) action_labels[valve_closed_idx, 1] += Events.valve_closed # <<<< non aborted, valve only <<<< - # <<<< map reward types <<<< + + # <<<< map reward types <<<< # put pixels timestamps in the third column action_labels = np.column_stack((action_labels, vr_data.index.values)) From 9528224d203963c7038e84ba3bc70a0b3abd73e5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 2 Jun 2025 19:31:15 +0100 Subject: [PATCH 412/658] map end of pre dark visible tunnel --- pixels/behaviours/virtual_reality.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index d2cbda9..3cf1abd 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -275,6 +275,31 @@ def _extract_action_labels(self, vr, vr_data): # TODO jun 27 2024 positional events and valve events needs mapping # NOTE: AL remove pre_dark_len + 10cm of his data + # get starting positions of all trials + start_pos = vr_data[in_tunnel].groupby( + "trial_count" + )["position_in_tunnel"].first() + + # plus pre dark + end_pre_dark = start_pos + vr.pre_dark_len + + def _first_post_pre_dark(df): + trial = df.name + pre_dark_end = end_pre_dark.loc[trial] + # mask and pick the first index + mask = df['position_in_tunnel'] >= pre_dark_end + if not mask.any(): + return None + return df.index[mask].min() + + pre_dark_end_t = vr_data[in_tunnel].groupby("trial_count").apply( + _first_post_pre_dark + ) + pre_dark_end_idx = vr_data.index.get_indexer( + pre_dark_end_t.dropna().astype(int) + ) + action_labels[pre_dark_end_idx, 1] += Events.pre_dark_end + # <<<< Event: end of pre dark length <<<< logging.info("\n>> Mapping vr action times...") # >>>> map reward types >>>> From da167d69b02f0927e1158b738dc10a34fd6088d4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 2 Jun 2025 19:31:39 +0100 Subject: [PATCH 413/658] add todo --- pixels/behaviours/virtual_reality.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 3cf1abd..a957a7a 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -274,6 +274,7 @@ def _extract_action_labels(self, vr, vr_data): # TODO jun 27 2024 positional events and valve events needs mapping + # >>>> Event: end of pre dark length >>>> # NOTE: AL remove pre_dark_len + 10cm of his data # 
get starting positions of all trials start_pos = vr_data[in_tunnel].groupby( @@ -300,9 +301,16 @@ def _first_post_pre_dark(df): ) action_labels[pre_dark_end_idx, 1] += Events.pre_dark_end # <<<< Event: end of pre dark length <<<< + + # >>>> Event: reward zone >>>> + assert 0 + # TODO jun 2 2025: + # do positional event mapping + vr_data.position_in_tunnel + # <<<< Event: reward zone <<<< + logging.info("\n>> Mapping vr action times...") - # >>>> map reward types >>>> # get non-zero reward types reward_not_none = (vr_data.reward_type != Outcomes.NONE) From 8b08373d6ffbe99de9c2ef1a91129493a58af9ae Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 2 Jun 2025 19:31:59 +0100 Subject: [PATCH 414/658] takes event label to check what is the actual starting position --- pixels/pixels_utils.py | 13 ++++++++++--- pixels/stream.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index bb2e76d..aa0c725 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1186,7 +1186,7 @@ def correct_group_id(rec): return group_ids -def get_vr_positional_data(trial_data): +def get_vr_positional_data(event, trial_data): """ Get positional firing rate and spike count for VR behaviour. @@ -1200,11 +1200,13 @@ def get_vr_positional_data(trial_data): data in 1cm resolution. """ pos_fr, occupancy = _get_vr_positional_neural_data( + event=event, positions=trial_data["positions"], data_type="spike_rate", data=trial_data["fr"], ) pos_fc, _ = _get_vr_positional_neural_data( + event=event, positions=trial_data["positions"], data_type="spiked", data=trial_data["spiked"], @@ -1213,7 +1215,7 @@ def get_vr_positional_data(trial_data): return {"pos_fr": pos_fr, "pos_fc": pos_fc, "occupancy": occupancy} -def _get_vr_positional_neural_data(positions, data_type, data): +def _get_vr_positional_neural_data(event, positions, data_type, data): """ Get positional neural data for VR behaviour. 
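The per-trial positional binning used here (floor to whole cm, aggregate within each bin, reindex onto the full tunnel) can be written compactly with pandas. A simplified per-trial version, with TUNNEL_RESET taken from vision_in_darkness.constants and the column layout reduced to a single trial:

    import numpy as np
    import pandas as pd

    def bin_trial_by_position(trial_pos, trial_fr, tunnel_reset):
        pos = trial_pos.dropna().apply(np.floor).astype(int)
        pos = pos[pos <= tunnel_reset]            # drop samples past tunnel reset
        fr = trial_fr.loc[pos.index]
        binned = fr.groupby(pos.values).mean()    # mean rate per 1 cm bin
        # for spike counts the aggregation would be .sum(), as in the "spiked" branch
        return binned.reindex(np.arange(0, tunnel_reset + 1))
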
@@ -1242,10 +1244,15 @@ def _get_vr_positional_neural_data(positions, data_type, data): logging.info(f"\n> Getting positional {data_type}...") # get constants from vd - from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END + from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END,\ + PRE_DARK_LEN + from pixels.behaviours.virtual_reality import Events # get the starting index for each trial (column) starts = positions.iloc[0, :].astype(int) + # if align to dark_onset, actual starting position is before that + if event == Events.dark_on: + starts = starts - PRE_DARK_LEN # create position indices indices = np.arange(0, TUNNEL_RESET+1) # create occupancy array for trials diff --git a/pixels/stream.py b/pixels/stream.py index 3dff052..47f6021 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -571,7 +571,7 @@ def get_positional_data( ) # get positional spike rate, spike count, and occupancy - positional_data = xut.get_vr_positional_data(trials) + positional_data = xut.get_vr_positional_data(event, trials) return positional_data From ee21cd9a0a0c33412e314ce3c9aebb8b90520b38 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 12:49:40 +0100 Subject: [PATCH 415/658] remove redundant imports; use InfFlag for automatic bit shift --- pixels/behaviours/virtual_reality.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index a957a7a..ced354d 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -7,9 +7,7 @@ # as a train of 0s and 1s, we align to the first 1s and the last 1s of a given # event, except for licks since it could only be on or off per frame. -from __future__ import annotations - -import pickle +from enum import IntFlag, auto import numpy as np import matplotlib.pyplot as plt @@ -17,14 +15,10 @@ from vision_in_darkness.base import Outcomes, Worlds, Conditions -from pixels import Experiment, PixelsError -import pixels.signal_utils as signal -from pixels import ioutils +from pixels import PixelsError from pixels.behaviours import Behaviour from pixels.configs import * -from common_utils import file_utils - class ActionLabels: """ From ec38b23c6e8f53286176dc0db0b1720e1ff14105 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 12:50:26 +0100 Subject: [PATCH 416/658] use intflag to automatically shift bit instead of doing it manually --- pixels/behaviours/virtual_reality.py | 81 +++++++++++++++------------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index ced354d..6828858 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -20,7 +20,7 @@ from pixels.configs import * -class ActionLabels: +class ActionLabels(IntFlag): """ These actions cover all possible trial types. @@ -35,19 +35,19 @@ class ActionLabels: # TODO jul 4 2024 only label trial type at the first frame of the trial to # make it easier for alignment??? 
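The move to enum.IntFlag with auto() keeps the labels usable as plain ints while making the bit bookkeeping automatic: auto() doubles the value for each new member, and members combine with |. One caveat worth checking against the project's Python version: the combo labels below reference members that were defined with auto() inside the same class body, which relies on auto() being resolved at definition time and is not guaranteed on every older interpreter. A self-contained illustration:

    from enum import IntFlag, auto

    class DemoEvents(IntFlag):
        trial_start = auto()   # 1
        gray_on = auto()       # 2
        gray_off = auto()      # 4

    labels = DemoEvents.trial_start | DemoEvents.gray_on   # value 3
    assert labels & DemoEvents.gray_on                      # membership via bitwise and
    assert int(DemoEvents.gray_off) == 4                    # still a plain int under the hood
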
# triggered vr trials - miss_light = 1 << 0 # 1 - miss_dark = 1 << 1 # 2 - triggered_light = 1 << 2 # 4 - triggered_dark = 1 << 3 # 8 - punished_light = 1 << 4 # 16 - punished_dark = 1 << 5 # 32 + miss_light = auto()#1 << 0 # 1 + miss_dark = auto()#1 << 1 # 2 + triggered_light = auto()#1 << 2 # 4 + triggered_dark = auto()#1 << 3 # 8 + punished_light = auto()#1 << 4 # 16 + punished_dark = auto()#1 << 5 # 32 # given reward - default_light = 1 << 6 # 64 - auto_light = 1 << 7 # 128 - auto_dark = 1 << 8 # 256 - reinf_light = 1 << 9 # 512 - reinf_dark = 1 << 10 # 1024 + default_light = auto()#1 << 6 # 64 + auto_light = auto()#1 << 7 # 128 + auto_dark = auto()#1 << 8 # 256 + reinf_light = auto()#1 << 9 # 512 + reinf_dark = auto()#1 << 10 # 1024 # combos miss = miss_light | miss_dark @@ -68,41 +68,48 @@ class ActionLabels: completed_dark = miss_dark | triggered_dark | auto_dark | reinf_dark -class Events: +class Events(IntFlag): """ Defines events that could happen during vr sessions. Events can be added on top of each other. """ # vr events - trial_start = 1 << 0 # 1 - gray_on = 1 << 1 # 2 - gray_off = 1 << 2 # 4 - light_on = 1 << 3 # 8 - light_off = 1 << 4 # 16 - dark_on = 1 << 5 # 32 - dark_off = 1 << 6 # 64 - punish_on = 1 << 7 # 128 - punish_off = 1 << 8 # 256 - trial_end = 1 << 9 # 512 + trial_start = auto()#1 << 0 # 1 + gray_on = auto()#1 << 1 # 2 + gray_off = auto()#1 << 2 # 4 + light_on = auto()#1 << 3 # 8 + light_off = auto()#1 << 4 # 16 + dark_on = auto()#1 << 5 # 32 + dark_off = auto()#1 << 6 # 64 + punish_on = auto()#1 << 7 # 128 + punish_off = auto()#1 << 8 # 256 + trial_end = auto()#1 << 9 # 512 # positional events - pre_dark_end = 1 << 10 # 50 cm - black = 1 << 12 # 0 - 60 cm - wall = 1 << 12 # in between landmarks - landmark1 = 1 << 13 # 110 - 130 cm - landmark2 = 1 << 14 # 190 - 210 cm - landmark3 = 1 << 15 # 270 - 290 cm - landmark4 = 1 << 16 # 350 - 370 cm - landmark5 = 1 << 17 # 430 - 450 cm - reward_zone = 1 << 18 # 460 - 495 cm + pre_dark_end = auto()#1 << 10 # 50 cm + black_off = auto()#1 << 11 # 0 - 60 cm + wall = auto()#1 << 12 # in between landmarks + landmark1_on = auto()#1 << 13 # 110 cm + landmark2_on = auto()#1 << 14 # 190 cm + landmark3_on = auto()#1 << 15 # 270 cm + landmark4_on = auto()#1 << 16 # 350 cm + landmark5_on = auto()#1 << 17 # 430 cm + reward_zone_on = auto()#1 << 18 # 460 cm + + landmark1_off = auto()#1 << 19 # 130 cm + landmark2_off = auto()#1 << 20 # 210 cm + landmark3_off = auto()#1 << 21 # 290 cm + landmark4_off = auto()#1 << 22 # 370 cm + landmark5_off = auto()#1 << 23 # 450 cm + reward_zone_off = auto()#1 << 24 # 495 cm # sensors - valve_open = 1 << 19 # 262144 - valve_closed = 1 << 20 # 524288 - licked = 1 << 21 # 1048576 - #run_start = 1 << 12 - #run_stop = 1 << 13 + valve_open = auto()#1 << 25 # 524288 + valve_closed = auto()#1 << 26 # 1048576 + licked = auto()#1 << 27 # 134217728 + #run_start = 1 << 28 + #run_stop = 1 << 29 # map trial outcome From 8175a22926c91ecaaa033d29a12d76eb7911c6f7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 12:50:45 +0100 Subject: [PATCH 417/658] use unsigned int32 dtype to increase upper bound --- pixels/behaviours/virtual_reality.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 6828858..68ba688 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -131,7 +131,7 @@ class VR(Behaviour): def _extract_action_labels(self, vr, vr_data): # create action label 
array for actions & events - action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.int32) + action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.uint32) # >>>> definitions >>>> # define in gray From f9a843a6ccfc8f885c0fde3649c610a3347794d6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 12:51:16 +0100 Subject: [PATCH 418/658] update refactor version --- pixels/behaviours/virtual_reality.py | 346 ++++++++++++++++++++------- 1 file changed, 264 insertions(+), 82 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 68ba688..1d79475 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -409,103 +409,285 @@ def _check_action_labels(self, vr_data, action_labels, plot=True): return action_labels - ''' - # TODO sep 30 2024: - # refactored code from chatgpt - # needs testing! - def _assign_event_label(self, action_labels, event_times, event_type, column=1): - """ - Helper function to assign event labels to action_labels array. - """ - event_indices = event_times.index - event_on_idx = np.where(event_indices.diff() != 1)[0] +''' +from enum import IntFlag, auto +from typing import NamedTuple, List, Tuple - # Find first and last timepoints for events - event_on_t = event_indices[event_on_idx] - event_off_t = np.append(event_indices[event_on_idx[1:] - 1], event_indices[-1]) +import numpy as np +import pandas as pd - event_on = event_times.index.get_indexer(event_on_t) - event_off = event_times.index.get_indexer(event_off_t) +from vision_in_darkness.base import Outcomes, Worlds, Conditions +from pixels.behaviours import Behaviour +from pixels import PixelsError - action_labels[event_on, column] += event_type['on'] - action_labels[event_off, column] += event_type['off'] - return action_labels +class Events(IntFlag): + """Bit-flags for everything that can happen to the animal in VR.""" + NONE = 0 + trial_start = auto() + gray_on = auto() + gray_off = auto() + light_on = auto() + light_off = auto() + dark_on = auto() + dark_off = auto() + punish_on = auto() + punish_off = auto() + pre_dark_end = auto() + reward_zone = auto() + valve_open = auto() + valve_closed = auto() + lick = auto() - def _map_trial_events(self, action_labels, vr_data, vr): - """ - Maps different trial events like gray, light, dark, and punishments. 
- """ - # Define event mappings for gray, light, dark, punishments - event_mappings = { - 'gray': {'on': Events.gray_on, 'off': Events.gray_off, - 'condition': vr_data.world_index == Worlds.GRAY}, - 'light': {'on': Events.light_on, 'off': Events.light_off, - 'condition': vr_data.world_index == Worlds.TUNNEL}, - 'dark': {'on': Events.dark_on, 'off': Events.dark_off, - 'condition': (vr_data.world_index == Worlds.DARK_5)\ - | (vr_data.world_index == Worlds.DARK_2_5)\ - | (vr_data.world_index == Worlds.DARK_FULL)}, - 'punish': {'on': Events.punish_on, 'off': Events.punish_off, - 'condition': vr_data.world_index == Worlds.WHITE}, - } - for event_name, event_type in event_mappings.items(): - event_times = vr_data[event_type['condition']] - action_labels = self._assign_event_label(action_labels, event_times, event_type) +class ActionLabels(IntFlag): + """Mutually exclusive trial‐outcome labels, plus helpful combos.""" + NONE = 0 + miss_light = auto() + miss_dark = auto() + triggered_light = auto() + triggered_dark = auto() + punished_light = auto() + punished_dark = auto() + default_light = auto() + auto_light = auto() + auto_dark = auto() + reinf_light = auto() + reinf_dark = auto() + + # handy OR-combos + miss = miss_light | miss_dark + triggered = triggered_light | triggered_dark + punished = punished_light | punished_dark + + light = (miss_light | triggered_light | punished_light + | default_light | auto_light | reinf_light) + dark = (miss_dark | triggered_dark | punished_dark + | auto_dark | reinf_dark) + rewarded_light = (triggered_light | default_light + | auto_light | reinf_light) + rewarded_dark = (triggered_dark | auto_dark + | reinf_dark) + given_light = default_light | auto_light | reinf_light + given_dark = auto_dark | reinf_dark + completed_light = miss_light | triggered_light | default_light \ + | auto_light | reinf_light + completed_dark = miss_dark | triggered_dark | auto_dark \ + | reinf_dark + + +class LabeledEvents(NamedTuple): + """Structured return from _extract_action_labels.""" + timestamps: np.ndarray # shape (N,) + outcome: np.ndarray # shape (N,) of ActionLabels + events: np.ndarray # shape (N,) of Events + - return action_labels +class VR(Behaviour): + """Behaviour subclass that extracts events & action labels from vr_data.""" - def _assign_trial_outcomes(self, action_labels, vr_data, vr): + def _extract_action_labels( + self, + vr_data: pd.DataFrame + ) -> LabeledEvents: """ - Assign outcomes for each trial, including rewards and punishments. 
+ Go over every frame in vr_data and assign: + - `events[i]` := bitmask of Events that occur at frame i + - `outcome[i]` := the trial‐outcome (one and only one ActionLabel) at i """ - for t, trial in enumerate(vr_data.trial_count.unique()): - # Extract trial-specific information - of_trial = (vr_data.trial_count == trial) - trial_idx = np.where(of_trial)[0] - - reward_not_none = (vr_data.reward_type != Outcomes.NONE) - reward_typed = vr_data[of_trial & reward_not_none] - trial_type = int(vr_data[of_trial].trial_type.unique()) - trial_type_str = trial_type_lookup.get(trial_type).lower() - - if reward_typed.size == 0\ - and vr_data[of_trial\ - & (vr_data.world_index == Worlds.WHITE)].size != 0: - # Handle punishment case - outcome = f"punished_{trial_type_str}" - else: - reward_type = int(reward_typed.reward_type.unique()) - outcome = _outcome_map.get(reward_type, "unknown") - - if reward_type == Outcomes.TRIGGERED: - outcome = f"{outcome}_{trial_type_str}" - - action_labels[trial_idx, 0] = getattr(ActionLabels, outcome, 0) - - if reward_type > Outcomes.NONE: - valve_open_idx = vr_data.index.get_indexer([reward_typed.index[0]]) - valve_closed_idx = vr_data.index.get_indexer([reward_typed.index[-1]]) - action_labels[valve_open_idx, 1] += Events.valve_open - action_labels[valve_closed_idx, 1] += Events.valve_closed + N = len(vr_data) + events = np.zeros(N, dtype=np.uint64) + outcome = np.zeros(N, dtype=np.uint64) + + # 1) stamp world‐based events (gray, light, dark, punish) via run masks + for evt, mask in self._world_event_masks(vr_data): + self._stamp_mask(events, mask, evt) + + # 2) stamp positional events (pre‐dark end, reward‐zone) + for evt, mask in self._position_event_masks(vr_data): + self._stamp_mask(events, mask, evt) + + # 3) stamp sensors: lick, valve open/closed + self._stamp_rising(events, vr_data.lick_detect.values, Events.lick) + # (if you have valve signals, do same for them) + + # 4) map each trial’s outcome into the outcome array + outcome_map = self._build_outcome_map() + for trial_id, group in vr_data.groupby("trial_count"): + idxs = group.index.values + flag = self._compute_outcome_flag(group, outcome_map) + outcome[idxs] = flag + + # stamp trial_start at the first frame of each triggered trial: + if flag in (ActionLabels.triggered_light, ActionLabels.triggered_dark): + first_idx = idxs[0] + events[first_idx] |= Events.trial_start + + # 5) return timestamps + two bit‐masked channels + return LabeledEvents( + timestamps = vr_data.index.values, + outcome = outcome.astype(np.uint32), + events = events.astype(np.uint32) + ) - return action_labels - def _extract_action_labels(self, vr, vr_data): + # ------------------------------------------------------------------------- + # Helpers to apply a flag wherever mask is true or on rising edges + # ------------------------------------------------------------------------- + + @staticmethod + def _stamp_mask( + storage: np.ndarray, + mask: np.ndarray, + flag: IntFlag + ) -> None: + """Bitwise‐OR `flag` into `storage` at every True in `mask`.""" + storage[mask] |= flag + + @staticmethod + def _stamp_rising( + storage: np.ndarray, + signal: np.ndarray, + flag: IntFlag + ) -> None: + """ + Find rising‐edge frames in a 0/1 `signal` array (diff == +1) + and stamp `flag` at those indices. """ - Extract action labels from VR data and assign events and outcomes. 
+ edges = np.flatnonzero(np.diff(signal, prepend=0) == 1) + storage[edges] |= flag + + + # ------------------------------------------------------------------------- + # Build lists of (EventFlag, boolean_mask) for world & positional events + # ------------------------------------------------------------------------- + + def _world_event_masks( + self, + df: pd.DataFrame + ) -> List[Tuple[Events, np.ndarray]]: + w = Worlds + return [ + # gray: enters in GRAY, leaves when GRAY ends + (Events.gray_on, self._first_in_run(df.world_index == w.GRAY)), + (Events.gray_off, self._last_in_run (df.world_index == w.GRAY)), + + # white (“punish” region) + (Events.punish_on, self._first_in_run(df.world_index == w.WHITE)), + (Events.punish_off, self._last_in_run (df.world_index == w.WHITE)), + + # light tunnel + (Events.light_on, self._first_in_run(df.world_index == w.TUNNEL)), + (Events.light_off, self._last_in_run (df.world_index == w.TUNNEL)), + + # dark tunnels (could be multiple dark worlds) + (Events.dark_on, self._first_in_run(df.world_index.isin(w.DARKS))), + (Events.dark_off, self._last_in_run (df.world_index.isin(w.DARKS))), + ] + + def _position_event_masks( + self, + df: pd.DataFrame + ) -> List[Tuple[Events, np.ndarray]]: + # pre‐dark end: when position ≥ (start_pos + pre_dark_len) + start_pos = df[df.world_index==Worlds.TUNNEL] \ + .groupby("trial_count")["position_in_tunnel"] \ + .first() + end_vals = start_pos + self.pre_dark_len + # mask per‐trial then merge + pre_dark_mask = np.zeros(len(df), bool) + for trial, thresh in end_vals.items(): + idx = df.trial_count==trial + pre_dark_mask[idx] |= (df.position_in_tunnel[idx] >= thresh) + + # reward zone: any frame at or beyond rz_start + rz_mask = df.position_in_tunnel >= self.rz_start + + return [ + (Events.pre_dark_end, pre_dark_mask), + (Events.reward_zone, rz_mask), + ] + + + # ------------------------------------------------------------------------- + # Utilities for detecting run‐start and run‐end of a boolean mask + # ------------------------------------------------------------------------- + + @staticmethod + def _first_in_run(mask: np.ndarray) -> np.ndarray: """ - action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.int32) + True exactly at the first True of each contiguous run in `mask`. + """ + idx = np.flatnonzero(mask) + if idx.size == 0: + return np.zeros_like(mask) + # a run starts where current True differs from previous True + starts = np.concatenate([[True], + mask[idx[1:]] != mask[idx[:-1]]]) + first_idx = idx[starts] + out = np.zeros_like(mask) + out[first_idx] = True + return out + + @staticmethod + def _last_in_run(mask: np.ndarray) -> np.ndarray: + """ + True exactly at the last True of each contiguous run in `mask`. 
+ """ + idx = np.flatnonzero(mask) + if idx.size == 0: + return np.zeros_like(mask) + ends = np.concatenate([mask[idx[:-1]] != mask[idx[1:]], [True]]) + last_idx = idx[ends] + out = np.zeros_like(mask) + out[last_idx] = True + return out - # Map events - action_labels = self._map_trial_events(action_labels, vr_data, vr) - # Assign trial outcomes - action_labels = self._assign_trial_outcomes(action_labels, vr_data, vr) + # ------------------------------------------------------------------------- + # Outcome‐mapping machinery + # ------------------------------------------------------------------------- - # Add timestamps to action labels - action_labels = np.column_stack((action_labels, vr_data.index.values)) + def _build_outcome_map(self) -> dict: + """ + Returns a dict mapping (Outcomes, Conditions) → ActionLabels. + """ + m = { + (Outcomes.ABORTED_DARK, Conditions.DARK): ActionLabels.miss_dark, + (Outcomes.ABORTED_LIGHT, Conditions.LIGHT): ActionLabels.miss_light, + (Outcomes.PUNISHED, Conditions.LIGHT): ActionLabels.punished_light, + (Outcomes.PUNISHED, Conditions.DARK): ActionLabels.punished_dark, + (Outcomes.AUTO_LIGHT, Conditions.LIGHT): ActionLabels.auto_light, + (Outcomes.DEFAULT, Conditions.LIGHT): ActionLabels.default_light, + (Outcomes.REINF_LIGHT, Conditions.LIGHT): ActionLabels.reinf_light, + (Outcomes.AUTO_DARK, Conditions.DARK): ActionLabels.auto_dark, + (Outcomes.REINF_DARK, Conditions.DARK): ActionLabels.reinf_dark, + + # triggered must include light vs dark + (Outcomes.TRIGGERED, Conditions.LIGHT): ActionLabels.triggered_light, + (Outcomes.TRIGGERED, Conditions.DARK): ActionLabels.triggered_dark, + } + return m - return action_labels - ''' + def _compute_outcome_flag( + self, + trial_df: pd.DataFrame, + outcome_map: dict + ) -> ActionLabels: + """ + Given one trial's DataFrame, look at its reward_type & trial_type + and return the matching ActionLabels member. + """ + rts = trial_df.reward_type.unique() + if rts.size == 0: + # no reward_type → unfinished or last‐trial abort + return ActionLabels.NONE + + rt = Outcomes(int(rts[0])) + cond = Conditions(int(trial_df.trial_type.iloc[0])) + key = (rt, cond) + try: + return outcome_map[key] + except KeyError: + raise PixelsError(f"No outcome mapping for {key}") +''' From 305f42ecac04849912d5e58fb8dcdfc278e4aceb Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:28:07 +0100 Subject: [PATCH 419/658] change name of the class to what it actually represents --- pixels/behaviours/virtual_reality.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 1d79475..7f54a00 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -20,14 +20,14 @@ from pixels.configs import * -class ActionLabels(IntFlag): +class TrialTypes(IntFlag): """ - These actions cover all possible trial types. + These cover all possible trial types. To align trials to more than one action type they can be bitwise OR'd i.e. `miss_light | miss_dark` will match all miss trials. - Actions can NOT be added on top of each other, they should be mutually + Trial types can NOT be added on top of each other, they should be mutually exclusive. """ # TODO jun 7 2024 does the name "action" make sense? 
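Not part of the patch series itself: a minimal, self-contained sketch of how OR-combined TrialTypes flags are meant to be consumed downstream, mirroring the `np.where(np.bitwise_and(outcomes, label))` selection that shows up in pixels/stream.py later in these commits. The flag values and the `outcomes` array below are made up for illustration; the real class uses `auto()`.

    import numpy as np
    from enum import IntFlag

    class TrialTypes(IntFlag):
        NONE = 0
        miss_light = 1
        miss_dark = 2
        triggered_light = 4
        triggered_dark = 8
        # named OR-combos match several mutually exclusive trial types at once
        miss = miss_light | miss_dark
        triggered = triggered_light | triggered_dark

    # one uint32 outcome value per behaviour frame (made-up values)
    outcomes = np.array(
        [TrialTypes.miss_light, 0, TrialTypes.triggered_dark, TrialTypes.miss_dark],
        dtype=np.uint32,
    )

    # frames belonging to any miss trial: bitwise AND against the combined flag
    miss_frames = np.where(np.bitwise_and(outcomes, TrialTypes.miss))[0]
    print(miss_frames)  # [0 3]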
@@ -337,8 +337,11 @@ def _first_post_pre_dark(df): & (vr_data[of_trial & in_white].size != 0): # punished outcome outcome = f"punished_{trial_type_str}" - action_labels[trial_idx, 0] = getattr(ActionLabels, outcome) - #action_labels[start_idx, 0] = getattr(ActionLabels, outcome) + outcomes_arr[trial_idx] = getattr(TrialTypes, outcome) + # or only mark the beginning of the trial? + #outcomes_arr[start_idx] = getattr(TrialTypes, outcome) + + #action_labels[start_idx, 0] = getattr(TrialTypes, outcome) # <<<< punished <<<< elif (reward_typed.size == 0)\ @@ -366,8 +369,9 @@ def _first_post_pre_dark(df): """ given & aborted """ outcome = _outcome_map[reward_type] # label outcome - action_labels[trial_idx, 0] = getattr(ActionLabels, outcome) - #action_labels[start_idx, 0] = getattr(ActionLabels, outcome) + outcomes_arr[trial_idx] = getattr(TrialTypes, outcome) + # or only mark the beginning of the trial? + #outcomes_arr[start_idx] = getattr(TrialTypes, outcome) # <<<< non punished <<<< # >>>> non aborted, valve only >>>> From a56f0c75434a3d6a972f8f50d9416840e7f3438a Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:28:43 +0100 Subject: [PATCH 420/658] reorganise positional events --- pixels/behaviours/virtual_reality.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 7f54a00..3b8317f 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -88,20 +88,27 @@ class Events(IntFlag): # positional events pre_dark_end = auto()#1 << 10 # 50 cm - black_off = auto()#1 << 11 # 0 - 60 cm + wall = auto()#1 << 12 # in between landmarks - landmark1_on = auto()#1 << 13 # 110 cm - landmark2_on = auto()#1 << 14 # 190 cm - landmark3_on = auto()#1 << 15 # 270 cm - landmark4_on = auto()#1 << 16 # 350 cm - landmark5_on = auto()#1 << 17 # 430 cm - reward_zone_on = auto()#1 << 18 # 460 cm + black_off = auto()#1 << 11 # 0 - 60 cm + + landmark1_on = auto()#1 << 13 # 110 cm landmark1_off = auto()#1 << 19 # 130 cm + + landmark2_on = auto()#1 << 14 # 190 cm landmark2_off = auto()#1 << 20 # 210 cm + + landmark3_on = auto()#1 << 15 # 270 cm landmark3_off = auto()#1 << 21 # 290 cm + + landmark4_on = auto()#1 << 16 # 350 cm landmark4_off = auto()#1 << 22 # 370 cm + + landmark5_on = auto()#1 << 17 # 430 cm landmark5_off = auto()#1 << 23 # 450 cm + + reward_zone_on = auto()#1 << 18 # 460 cm reward_zone_off = auto()#1 << 24 # 495 cm # sensors From a0092fd72e63649ff3851f6c81c3090c1e338b85 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:30:04 +0100 Subject: [PATCH 421/658] add notes and change name of the array --- pixels/behaviours/virtual_reality.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 3b8317f..845f5f3 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -137,9 +137,11 @@ class Events(IntFlag): class VR(Behaviour): def _extract_action_labels(self, vr, vr_data): - # create action label array for actions & events - action_labels = np.zeros((vr_data.shape[0], 2), dtype=np.uint32) + # NOTE: this func still called _extract_action_labels cuz it is + # inherited from motor data analysis, where each unit is an action. 
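An aside, not from any commit here: the next two patches swap `+=` into `action_labels` for per-column views written with `np.bitwise_or.at`. A standalone numpy sketch of the behaviours that motivate that change (the shapes and the flag value are invented):

    import numpy as np

    FLAG = np.uint32(1 << 3)              # stand-in for a single Events bit

    labels = np.zeros((4, 2), dtype=np.uint32)
    events = labels[:, 1]                 # basic slice -> view: writes reach `labels`

    np.bitwise_or.at(events, [0, 2, 2], FLAG)
    print(labels[:, 1])                   # [8 0 8 0]  a repeated index still leaves one clean bit

    events[[0, 2]] += FLAG                # addition is not idempotent for bit flags
    print(labels[:, 1])                   # [16 0 16 0]  the flag now reads as a different event bit

    col = labels[[0, 1], 1]               # fancy indexing -> copy
    col |= FLAG                           # only the copy changes; `labels` is untouched
    print(labels[:2, 1])                  # [16 0]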
+ # create action label array for trial types & events + labels = np.zeros((vr_data.shape[0], 2), dtype=np.uint32) # >>>> definitions >>>> # define in gray in_gray = (vr_data.world_index == Worlds.GRAY) @@ -347,8 +349,6 @@ def _first_post_pre_dark(df): outcomes_arr[trial_idx] = getattr(TrialTypes, outcome) # or only mark the beginning of the trial? #outcomes_arr[start_idx] = getattr(TrialTypes, outcome) - - #action_labels[start_idx, 0] = getattr(TrialTypes, outcome) # <<<< punished <<<< elif (reward_typed.size == 0)\ @@ -397,9 +397,9 @@ def _first_post_pre_dark(df): # <<<< map reward types <<<< # put pixels timestamps in the third column - action_labels = np.column_stack((action_labels, vr_data.index.values)) + labels = np.column_stack((labels, vr_data.index.values)) - return action_labels + return labels def _check_action_labels(self, vr_data, action_labels, plot=True): From d53833c73851a4b5578affc3988a14f4f9da6e1e Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:32:12 +0100 Subject: [PATCH 422/658] create view to refer to array --- pixels/behaviours/virtual_reality.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 845f5f3..4db9455 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -142,6 +142,12 @@ def _extract_action_labels(self, vr, vr_data): # create action label array for trial types & events labels = np.zeros((vr_data.shape[0], 2), dtype=np.uint32) + # create arr view to make it more clear which column i'm writing into + # NOTE: the original array would not get updated if i do fancy indexing + # on the original, i.e., labels[mask, 0] + outcomes_arr = labels[:, 0] + events_arr = labels[:, 1] + # >>>> definitions >>>> # define in gray in_gray = (vr_data.world_index == Worlds.GRAY) From 45c00efe7e658f885de62dac822370a53e19be8e Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:32:44 +0100 Subject: [PATCH 423/658] use bitwise_or to assign values to labels array --- pixels/behaviours/virtual_reality.py | 32 +++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 4db9455..3c22b28 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -178,13 +178,14 @@ def _extract_action_labels(self, vr, vr_data): gray_on_t = gray_idx[grays] # find their index in vr data gray_on = vr_data.index.get_indexer(gray_on_t) - action_labels[gray_on, 1] += Events.gray_on + # bitwise_or.at will do: for each idx in gray_on: dst[idx] |= Events.gray_on + np.bitwise_or.at(events_arr, gray_on, Events.gray_on) # find time for last frame of gray gray_off_t = np.append(gray_idx[grays[1:] - 1], gray_idx[-1]) # find their index in vr data gray_off = vr_data.index.get_indexer(gray_off_t) - action_labels[gray_off, 1] += Events.gray_off + np.bitwise_or.at(events_arr, gray_off, Events.gray_off) # <<<< gray <<<< # >>>> punishment >>>> @@ -197,18 +198,18 @@ def _extract_action_labels(self, vr, vr_data): punish_on_t = punish_idx[punishes] # find their index in vr data punish_on = vr_data.index.get_indexer(punish_on_t) - action_labels[punish_on, 1] += Events.punish_on + np.bitwise_or.at(events_arr, punish_on, Events.punish_on) # find time for last frame of punish punish_off_t = np.append(punish_idx[punishes[1:] - 1], punish_idx[-1]) # find their index in vr data punish_off = 
vr_data.index.get_indexer(punish_off_t) - action_labels[punish_off, 1] += Events.punish_off + np.bitwise_or.at(events_arr, punish_off, Events.punish_off) # <<<< punishment <<<< # >>>> trial ends >>>> # trial ends right before punishment starts - action_labels[punish_on-1, 1] += Events.trial_end + np.bitwise_or.at(events_arr, punish_on-1, Events.trial_end) # for non punished trials, right before gray on is when trial ends, and # the last frame of the session @@ -218,7 +219,7 @@ def _extract_action_labels(self, vr, vr_data): no_punished_t = pre_gray_on.drop(punish_off_t).index # get index of trial ends in non punished trials no_punished_idx = vr_data.index.get_indexer(no_punished_t) - action_labels[no_punished_idx, 1] += Events.trial_end + np.bitwise_or.at(events_arr, no_punished_idx, Events.trial_end) # <<<< trial ends <<<< # >>>> light >>>> @@ -230,7 +231,7 @@ def _extract_action_labels(self, vr, vr_data): light_on_t = light_idx[lights] # get index of when light turns on light_on = vr_data.index.get_indexer(light_on_t) - action_labels[light_on, 1] += Events.light_on + np.bitwise_or.at(events_arr, light_on, Events.light_on) # get interval of possible starting position start_interval = int(vr.meta_item('rand_start_int')) @@ -240,7 +241,7 @@ def _extract_action_labels(self, vr, vr_data): vr_data.iloc[light_on].position_in_tunnel % start_interval == 0 )[0]] # label trial starts - action_labels[trial_starts, 1] += Events.trial_start + np.bitwise_or.at(events_arr, trial_starts, Events.trial_start) if not trial_starts.size == vr_data[in_tunnel].trial_count.max(): raise PixelsError(f"Number of trials does not equal to " @@ -251,7 +252,7 @@ def _extract_action_labels(self, vr, vr_data): # last frame of light light_off_t = np.append(light_idx[lights[1:] - 1], light_idx[-1]) light_off = vr_data.index.get_indexer(light_off_t) - action_labels[light_off, 1] += Events.light_off + np.bitwise_or.at(events_arr, light_off, Events.light_off) # <<<< light <<<< # NOTE: dark trials should in theory have EQUAL index pre_dark_end_idx @@ -274,21 +275,22 @@ def _extract_action_labels(self, vr, vr_data): # first frame of dark dark_on_t = dark_idx[darks] dark_on = vr_data.index.get_indexer(dark_on_t) - action_labels[dark_on, 1] += Events.dark_on + np.bitwise_or.at(events_arr, dark_on, Events.dark_on) # last frame of dark dark_off_t = np.append(dark_idx[darks[1:] - 1], dark_idx[-1]) dark_off = vr_data.index.get_indexer(dark_off_t) - action_labels[dark_off, 1] += Events.dark_off + np.bitwise_or.at(events_arr, dark_off, Events.dark_off) # <<<< dark <<<< # >>>> licks >>>> lick_onsets = np.diff(vr_data.lick_detect, prepend=0) licked_idx = np.where(lick_onsets == 1)[0] - action_labels[licked_idx, 1] += Events.licked + np.bitwise_or.at(events_arr, licked_idx, Events.licked) # <<<< licks <<<< # TODO jun 27 2024 positional events and valve events needs mapping + # >>>> positional event mapping >>>> # >>>> Event: end of pre dark length >>>> # NOTE: AL remove pre_dark_len + 10cm of his data @@ -315,7 +317,7 @@ def _first_post_pre_dark(df): pre_dark_end_idx = vr_data.index.get_indexer( pre_dark_end_t.dropna().astype(int) ) - action_labels[pre_dark_end_idx, 1] += Events.pre_dark_end + np.bitwise_or.at(events_arr, pre_dark_end_idx, Events.pre_dark_end) # <<<< Event: end of pre dark length <<<< # >>>> Event: reward zone >>>> @@ -392,12 +394,12 @@ def _first_post_pre_dark(df): if reward_type > Outcomes.NONE: # map valve open valve_open_idx = vr_data.index.get_indexer([reward_typed.index[0]]) - action_labels[valve_open_idx, 1] += 
Events.valve_open + np.bitwise_or.at(events_arr, valve_open_idx, Events.valve_open) # map valve closed valve_closed_idx = vr_data.index.get_indexer( [reward_typed.index[-1]] ) - action_labels[valve_closed_idx, 1] += Events.valve_closed + np.bitwise_or.at(events_arr, valve_closed_idx, Events.valve_closed) # <<<< non aborted, valve only <<<< # <<<< map reward types <<<< From 6cdffd1a770f17f9c0ab805858d62cb05d3c8f4e Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:33:05 +0100 Subject: [PATCH 424/658] map more positional events --- pixels/behaviours/virtual_reality.py | 31 ++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 3c22b28..b58b608 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -320,11 +320,34 @@ def _first_post_pre_dark(df): np.bitwise_or.at(events_arr, pre_dark_end_idx, Events.pre_dark_end) # <<<< Event: end of pre dark length <<<< + # >>>> Event: landmark 1 >>>> + + # <<<< Event: landmark 1 <<<< + # >>>> Event: reward zone >>>> - assert 0 - # TODO jun 2 2025: - # do positional event mapping - vr_data.position_in_tunnel + # all indices in reward zone + in_zone = ( + vr_data.position_in_tunnel >= vr.reward_zone_start + ) & ( + vr_data.position_in_tunnel <= vr.reward_zone_end + ) + # reward zone on + zone_on_t = ( + vr_data[in_zone] + .groupby("trial_count") + .apply(lambda g: g.index.min()) + ) + zone_on_idx = vr_data.index.get_indexer(zone_on_t) + np.bitwise_or.at(events_arr, zone_on_idx, Events.reward_zone_on) + + # reward zone off + zone_off_t = ( + vr_data[in_zone] + .groupby("trial_count") + .apply(lambda g: g.index.max()) + ) + zone_off_idx = vr_data.index.get_indexer(zone_off_t) + 1 + np.bitwise_or.at(events_arr, zone_off_idx, Events.reward_zone_off) # <<<< Event: reward zone <<<< logging.info("\n>> Mapping vr action times...") From 2455aa5fe0726314c7fc831e164bafcaad3e57c4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:33:30 +0100 Subject: [PATCH 425/658] change name to represent what it actually is --- pixels/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 47f6021..3146be5 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -107,14 +107,14 @@ def _get_aligned_trials( spikes = self.get_spike_times(units) # get action and event label file - actions = action_labels["outcome"] + outcomes = action_labels["outcome"] events = action_labels["events"] # get timestamps index of behaviour in self.BEHAVIOUR_SAMPLE_RATE hz, to convert # it to ms, do timestamps*1000/self.BEHAVIOUR_SAMPLE_RATE timestamps = action_labels["timestamps"] # select frames of wanted trial type - trials = np.where(np.bitwise_and(actions, label))[0] + trials = np.where(np.bitwise_and(outcomes, label))[0] # map starts by event starts = np.where(np.bitwise_and(events, event))[0] # map starts by end event From 7947a688657b672ff32e087ded37751e1ecbff94 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 3 Jun 2025 14:37:39 +0100 Subject: [PATCH 426/658] use last frame in zone as reward zone off --- pixels/behaviours/virtual_reality.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index b58b608..dfc618e 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -331,7 +331,7 @@ def 
_first_post_pre_dark(df): ) & ( vr_data.position_in_tunnel <= vr.reward_zone_end ) - # reward zone on + # first frame in reward zone zone_on_t = ( vr_data[in_zone] .groupby("trial_count") @@ -340,13 +340,13 @@ def _first_post_pre_dark(df): zone_on_idx = vr_data.index.get_indexer(zone_on_t) np.bitwise_or.at(events_arr, zone_on_idx, Events.reward_zone_on) - # reward zone off + # last frame in reward zone zone_off_t = ( vr_data[in_zone] .groupby("trial_count") .apply(lambda g: g.index.max()) ) - zone_off_idx = vr_data.index.get_indexer(zone_off_t) + 1 + zone_off_idx = vr_data.index.get_indexer(zone_off_t) np.bitwise_or.at(events_arr, zone_off_idx, Events.reward_zone_off) # <<<< Event: reward zone <<<< From ab640250a65881d3e41a6234032489386c2c2393 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:20:04 +0100 Subject: [PATCH 427/658] use NamedTuple --- pixels/behaviours/virtual_reality.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index dfc618e..e1b269c 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -8,6 +8,7 @@ # event, except for licks since it could only be on or off per frame. from enum import IntFlag, auto +from typing import NamedTuple import numpy as np import matplotlib.pyplot as plt @@ -519,11 +520,17 @@ class ActionLabels(IntFlag): class LabeledEvents(NamedTuple): - """Structured return from _extract_action_labels.""" - timestamps: np.ndarray # shape (N,) - outcome: np.ndarray # shape (N,) of ActionLabels - events: np.ndarray # shape (N,) of Events - + """Return type: timestamps + bitfields for outcome & events.""" + timestamps: np.ndarray # shape (N,) + outcome: np.ndarray # shape (N,) dtype uint32 + events: np.ndarray # shape (N,) dtype uint32 + +class WorldMasks(NamedTuple): + in_gray: pd.Series + in_dark: pd.Series + in_white: pd.Series + in_light: pd.Series + in_tunnel: pd.Series class VR(Behaviour): """Behaviour subclass that extracts events & action labels from vr_data.""" From 903ec4c5d2fe3f399ea360cf7f3f9b1974a61aad Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:20:22 +0100 Subject: [PATCH 428/658] add none type --- pixels/behaviours/virtual_reality.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index e1b269c..6ae1bb1 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -36,6 +36,7 @@ class TrialTypes(IntFlag): # TODO jul 4 2024 only label trial type at the first frame of the trial to # make it easier for alignment??? # triggered vr trials + NONE = 0 miss_light = auto()#1 << 0 # 1 miss_dark = auto()#1 << 1 # 2 triggered_light = auto()#1 << 2 # 4 @@ -75,6 +76,7 @@ class Events(IntFlag): Events can be added on top of each other. 
""" + NONE = 0 # vr events trial_start = auto()#1 << 0 # 1 gray_on = auto()#1 << 1 # 2 From b8a5686af12de0698689e7520fd299d83ebab383 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:21:05 +0100 Subject: [PATCH 429/658] use enum & auto to map labels --- pixels/behaviours/virtual_reality.py | 64 +++++++++++++++------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 6ae1bb1..7210576 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -78,48 +78,52 @@ class Events(IntFlag): """ NONE = 0 # vr events - trial_start = auto()#1 << 0 # 1 - gray_on = auto()#1 << 1 # 2 - gray_off = auto()#1 << 2 # 4 - light_on = auto()#1 << 3 # 8 - light_off = auto()#1 << 4 # 16 - dark_on = auto()#1 << 5 # 32 - dark_off = auto()#1 << 6 # 64 - punish_on = auto()#1 << 7 # 128 - punish_off = auto()#1 << 8 # 256 - trial_end = auto()#1 << 9 # 512 + trial_start = auto() # 1 + gray_on = auto() # 2 + gray_off = auto() # 4 + light_on = auto() # 8 + light_off = auto() # 16 + dark_on = auto() # 32 + dark_off = auto() # 64 + punish_on = auto() # 128 + punish_off = auto() # 256 + trial_end = auto() # 512 # positional events - pre_dark_end = auto()#1 << 10 # 50 cm + pre_dark_end = auto()# 50 cm - wall = auto()#1 << 12 # in between landmarks + # TODO jun 4 2025: + # how to mark wall? + wall = auto()# in between landmarks - black_off = auto()#1 << 11 # 0 - 60 cm + # black wall + landmark0_on = auto()# wherever trial starts, before 60cm + landmark0_off = auto()# 60 cm - landmark1_on = auto()#1 << 13 # 110 cm - landmark1_off = auto()#1 << 19 # 130 cm + landmark1_on = auto()# 110 cm + landmark1_off = auto()# 130 cm - landmark2_on = auto()#1 << 14 # 190 cm - landmark2_off = auto()#1 << 20 # 210 cm + landmark2_on = auto()# 190 cm + landmark2_off = auto()# 210 cm - landmark3_on = auto()#1 << 15 # 270 cm - landmark3_off = auto()#1 << 21 # 290 cm + landmark3_on = auto()# 270 cm + landmark3_off = auto()# 290 cm - landmark4_on = auto()#1 << 16 # 350 cm - landmark4_off = auto()#1 << 22 # 370 cm + landmark4_on = auto()# 350 cm + landmark4_off = auto()# 370 cm - landmark5_on = auto()#1 << 17 # 430 cm - landmark5_off = auto()#1 << 23 # 450 cm + landmark5_on = auto()# 430 cm + landmark5_off = auto()# 450 cm - reward_zone_on = auto()#1 << 18 # 460 cm - reward_zone_off = auto()#1 << 24 # 495 cm + reward_zone_on = auto()# 460 cm + reward_zone_off = auto()# 495 cm # sensors - valve_open = auto()#1 << 25 # 524288 - valve_closed = auto()#1 << 26 # 1048576 - licked = auto()#1 << 27 # 134217728 - #run_start = 1 << 28 - #run_stop = 1 << 29 + valve_open = auto() + valve_closed = auto() + licked = auto() + #run_start = auto() + #run_stop = auto() # map trial outcome From 9bf0dd91b1760c852940f41c5481c4d4855a95d0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:24:10 +0100 Subject: [PATCH 430/658] use unsigned int32 to ensure no overflowing --- pixels/behaviours/virtual_reality.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 7210576..8357ed1 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -141,19 +141,15 @@ class Events(IntFlag): # function to look up trial type trial_type_lookup = {v: k for k, v in vars(Conditions).items()} +''' class VR(Behaviour): def _extract_action_labels(self, vr, vr_data): # NOTE: this func still 
called _extract_action_labels cuz it is # inherited from motor data analysis, where each unit is an action. - # create action label array for trial types & events - labels = np.zeros((vr_data.shape[0], 2), dtype=np.uint32) - # create arr view to make it more clear which column i'm writing into - # NOTE: the original array would not get updated if i do fancy indexing - # on the original, i.e., labels[mask, 0] - outcomes_arr = labels[:, 0] - events_arr = labels[:, 1] + events_arr = np.zeros(len(vr_data), dtype=np.uint32) + outcomes_arr = np.zeros(len(vr_data), dtype=np.uint32) # >>>> definitions >>>> # define in gray From 48052702b96bb7708c6eae9f24552b178844e0a3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:24:34 +0100 Subject: [PATCH 431/658] use DARKS --- pixels/behaviours/virtual_reality.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 8357ed1..51f92de 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -155,9 +155,10 @@ def _extract_action_labels(self, vr, vr_data): # define in gray in_gray = (vr_data.world_index == Worlds.GRAY) # define in dark - in_dark = (vr_data.world_index == Worlds.DARK_5)\ - | (vr_data.world_index == Worlds.DARK_2_5)\ - | (vr_data.world_index == Worlds.DARK_FULL) + in_dark = vr_data.world_index.isin(Worlds.DARKS) + #in_dark = (vr_data.world_index == Worlds.DARK_5)\ + # | (vr_data.world_index == Worlds.DARK_2_5)\ + # | (vr_data.world_index == Worlds.DARK_FULL) # define in white in_white = (vr_data.world_index == Worlds.WHITE) # define in tunnel From 5ec792bf7c33aad5d7a550a1044ce1bef9810dde Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:24:48 +0100 Subject: [PATCH 432/658] add todo --- pixels/behaviours/virtual_reality.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 51f92de..78ed011 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -325,6 +325,10 @@ def _first_post_pre_dark(df): # <<<< Event: end of pre dark length <<<< # >>>> Event: landmark 1 >>>> + # jun 3 2025: + # CONTINUE HERE! + # might as well implement the new refactored code before extend more + # landmarks... 
# <<<< Event: landmark 1 <<<< From 1fc4307458fd7c4e43ab8b7d81acf80216d60cc7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:25:10 +0100 Subject: [PATCH 433/658] return more structured --- pixels/behaviours/virtual_reality.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 78ed011..6f25056 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -435,10 +435,12 @@ def _first_post_pre_dark(df): # <<<< map reward types <<<< - # put pixels timestamps in the third column - labels = np.column_stack((labels, vr_data.index.values)) - - return labels + # return typed arrays + return LabeledEvents( + outcome = outcomes_arr, + events = events_arr, + timestamps = vr_data.index.values, + ) def _check_action_labels(self, vr_data, action_labels, plot=True): From 227d12e5f0254a5f302f7c57f2750dbf23ec9ff4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:25:26 +0100 Subject: [PATCH 434/658] remove redundant --- pixels/behaviours/virtual_reality.py | 65 ---------------------------- 1 file changed, 65 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 6f25056..4c8db9c 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -460,72 +460,7 @@ def _check_action_labels(self, vr_data, action_labels, plot=True): plt.show() return action_labels - ''' -from enum import IntFlag, auto -from typing import NamedTuple, List, Tuple - -import numpy as np -import pandas as pd - -from vision_in_darkness.base import Outcomes, Worlds, Conditions -from pixels.behaviours import Behaviour -from pixels import PixelsError - - -class Events(IntFlag): - """Bit-flags for everything that can happen to the animal in VR.""" - NONE = 0 - trial_start = auto() - gray_on = auto() - gray_off = auto() - light_on = auto() - light_off = auto() - dark_on = auto() - dark_off = auto() - punish_on = auto() - punish_off = auto() - pre_dark_end = auto() - reward_zone = auto() - valve_open = auto() - valve_closed = auto() - lick = auto() - - -class ActionLabels(IntFlag): - """Mutually exclusive trial‐outcome labels, plus helpful combos.""" - NONE = 0 - miss_light = auto() - miss_dark = auto() - triggered_light = auto() - triggered_dark = auto() - punished_light = auto() - punished_dark = auto() - default_light = auto() - auto_light = auto() - auto_dark = auto() - reinf_light = auto() - reinf_dark = auto() - - # handy OR-combos - miss = miss_light | miss_dark - triggered = triggered_light | triggered_dark - punished = punished_light | punished_dark - - light = (miss_light | triggered_light | punished_light - | default_light | auto_light | reinf_light) - dark = (miss_dark | triggered_dark | punished_dark - | auto_dark | reinf_dark) - rewarded_light = (triggered_light | default_light - | auto_light | reinf_light) - rewarded_dark = (triggered_dark | auto_dark - | reinf_dark) - given_light = default_light | auto_light | reinf_light - given_dark = auto_dark | reinf_dark - completed_light = miss_light | triggered_light | default_light \ - | auto_light | reinf_light - completed_dark = miss_dark | triggered_dark | auto_dark \ - | reinf_dark class LabeledEvents(NamedTuple): From fed00667a86050b6f437429006b4a0b0abc46ee3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:25:52 +0100 Subject: [PATCH 435/658] define world mask so it can be used globally --- 
pixels/behaviours/virtual_reality.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 4c8db9c..b8d2f7f 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -477,7 +477,26 @@ class WorldMasks(NamedTuple): in_tunnel: pd.Series class VR(Behaviour): - """Behaviour subclass that extracts events & action labels from vr_data.""" + """Behaviour subclass to extract action labels & events from vr_data.""" + def _get_world_masks(self, df: pd.DataFrame) -> WorldMasks: + # define in gray + in_gray = (df.world_index == Worlds.GRAY) + # define in dark + in_dark = (df.world_index.isin(Worlds.DARKS)) + # define in white + in_white = (df.world_index == Worlds.WHITE) + # define in light + in_light = (df.world_index == Worlds.TUNNEL) + # define in tunnel + in_tunnel = (~in_gray & ~in_white) + + return WorldMasks( + in_gray = in_gray, + in_dark = in_dark, + in_white = in_white, + in_light = in_light, + in_tunnel = in_tunnel, + ) def _extract_action_labels( self, From 98a749a90b54b83b4c7d4997905a99a9d8f198a3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:26:12 +0100 Subject: [PATCH 436/658] input vr behaviour obj --- pixels/behaviours/virtual_reality.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index b8d2f7f..10cc1b8 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -501,6 +501,8 @@ def _get_world_masks(self, df: pd.DataFrame) -> WorldMasks: def _extract_action_labels( self, vr_data: pd.DataFrame + session, + data: pd.DataFrame ) -> LabeledEvents: """ Go over every frame in vr_data and assign: From 245abd1fa6c47f13e45808492b023487e9f71d37 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:27:17 +0100 Subject: [PATCH 437/658] use uint32 not 64 --- pixels/behaviours/virtual_reality.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 10cc1b8..922dabb 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -500,7 +500,6 @@ def _get_world_masks(self, df: pd.DataFrame) -> WorldMasks: def _extract_action_labels( self, - vr_data: pd.DataFrame session, data: pd.DataFrame ) -> LabeledEvents: @@ -509,9 +508,9 @@ def _extract_action_labels( - `events[i]` := bitmask of Events that occur at frame i - `outcome[i]` := the trial‐outcome (one and only one ActionLabel) at i """ - N = len(vr_data) - events = np.zeros(N, dtype=np.uint64) - outcome = np.zeros(N, dtype=np.uint64) + N = len(data) + events_arr = np.zeros(N, dtype=np.uint32) + outcomes_arr = np.zeros(N, dtype=np.uint32) # 1) stamp world‐based events (gray, light, dark, punish) via run masks for evt, mask in self._world_event_masks(vr_data): From b90dfb19312d09fc01173fe6540fc452ca1a77a8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:28:11 +0100 Subject: [PATCH 438/658] fix main logic --- pixels/behaviours/virtual_reality.py | 55 ++++++++++++++++------------ 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 922dabb..f506aee 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -512,35 +512,44 @@ def _extract_action_labels( events_arr = np.zeros(N, 
dtype=np.uint32) outcomes_arr = np.zeros(N, dtype=np.uint32) - # 1) stamp world‐based events (gray, light, dark, punish) via run masks - for evt, mask in self._world_event_masks(vr_data): - self._stamp_mask(events, mask, evt) + # world index based events + for event, idx in self._world_event_indices(data).items(): + mask = self._get_index(data, idx.to_numpy()) + self._stamp_mask(events_arr, mask, event) - # 2) stamp positional events (pre‐dark end, reward‐zone) - for evt, mask in self._position_event_masks(vr_data): - self._stamp_mask(events, mask, evt) + # positional events (pre‐dark end, landmarks, reward‐zone) + for event, idx in self._position_event_indices(session, data).items(): + mask = self._get_index(data, idx.to_numpy()) + self._stamp_mask(events_arr, mask, event) - # 3) stamp sensors: lick, valve open/closed - self._stamp_rising(events, vr_data.lick_detect.values, Events.lick) - # (if you have valve signals, do same for them) + # sensors: lick rising‐edge + self._stamp_rising(events_arr, data.lick_detect.values, Events.licked) - # 4) map each trial’s outcome into the outcome array + # map trial outcomes outcome_map = self._build_outcome_map() - for trial_id, group in vr_data.groupby("trial_count"): - idxs = group.index.values - flag = self._compute_outcome_flag(group, outcome_map) - outcome[idxs] = flag + for trial_id, group in data.groupby("trial_count"): + trial_t = group.index.values + flag, valve_events = self._compute_outcome_flag( + trial_id, + group, + outcome_map, + session, + ) + + # save outcomes + idx = self._get_index(data, trial_t) + outcomes_arr[idx] = flag + + if len(valve_events) > 0: + for v_event, v_idx in valve_events.items(): + valve_mask = self._get_index(data, v_idx) + self._stamp_mask(events_arr, valve_mask, v_event) - # stamp trial_start at the first frame of each triggered trial: - if flag in (ActionLabels.triggered_light, ActionLabels.triggered_dark): - first_idx = idxs[0] - events[first_idx] |= Events.trial_start - - # 5) return timestamps + two bit‐masked channels + # return typed arrays return LabeledEvents( - timestamps = vr_data.index.values, - outcome = outcome.astype(np.uint32), - events = events.astype(np.uint32) + timestamps = data.index.values, + outcome = outcomes_arr, + events = events_arr, ) From 17d45ad03a411d9096e14e954644dd88fc684510 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:29:01 +0100 Subject: [PATCH 439/658] add doc --- pixels/behaviours/virtual_reality.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index f506aee..2fdd5e8 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -504,9 +504,10 @@ def _extract_action_labels( data: pd.DataFrame ) -> LabeledEvents: """ - Go over every frame in vr_data and assign: - - `events[i]` := bitmask of Events that occur at frame i - - `outcome[i]` := the trial‐outcome (one and only one ActionLabel) at i + Go over every frame in data and assign: + - `events_arr[i]` := bitmask of Events that occur at frame i + - `outcomes_arr[i]` := the trial‐outcome (one and only one TrialType) + at i """ N = len(data) events_arr = np.zeros(N, dtype=np.uint32) From 1a5c3aeea6af9f391d062b1fd18d0f031e325fd4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:33:46 +0100 Subject: [PATCH 440/658] add stamping methods --- pixels/behaviours/virtual_reality.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff 
--git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py
index 2fdd5e8..daeccba 100644
--- a/pixels/behaviours/virtual_reality.py
+++ b/pixels/behaviours/virtual_reality.py
@@ -555,30 +555,23 @@ def _extract_action_labels(
 
 
     # -------------------------------------------------------------------------
-    # Helpers to apply a flag wherever mask is true or on rising edges
+    # Core stamping helpers
     # -------------------------------------------------------------------------
 
     @staticmethod
-    def _stamp_mask(
-        storage: np.ndarray,
-        mask: np.ndarray,
-        flag: IntFlag
-    ) -> None:
+    def _stamp_mask(array: np.ndarray, mask: np.ndarray, flag: IntFlag):
         """Bitwise‐OR `flag` into `storage` at every True in `mask`."""
-        storage[mask] |= flag
+        np.bitwise_or.at(array, mask, flag)
 
     @staticmethod
-    def _stamp_rising(
-        storage: np.ndarray,
-        signal: np.ndarray,
-        flag: IntFlag
-    ) -> None:
+    def _stamp_rising(array: np.ndarray, signal: np.ndarray, flag: IntFlag):
         """
         Find rising‐edge frames in a 0/1 `signal` array (diff == +1)
         and stamp `flag` at those indices.
         """
+        # extract edges
        edges = np.flatnonzero(np.diff(signal, prepend=0) == 1)
-        storage[edges] |= flag
+        np.bitwise_or.at(array, edges, flag)
 
 
     # -------------------------------------------------------------------------

From 6f4b4c90d87d94e1b6fde5bccd3f2e96a3cb1d41 Mon Sep 17 00:00:00 2001
From: amz_office
Date: Wed, 4 Jun 2025 18:34:35 +0100
Subject: [PATCH 441/658] use helper func to get world index mask

---
 pixels/behaviours/virtual_reality.py | 22 ++--------------------
 1 file changed, 2 insertions(+), 20 deletions(-)

diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py
index daeccba..ef1207c 100644
--- a/pixels/behaviours/virtual_reality.py
+++
b/pixels/behaviours/virtual_reality.py @@ -578,12 +578,82 @@ def _stamp_rising(array: np.ndarray, signal: np.ndarray, flag: IntFlag): # Build world‐based event masks # ------------------------------------------------------------------------- + def _world_event_indices( + self, + df: pd.DataFrame + ) -> dict[Events, pd.Series]: + """ + Build (event_flag, boolean_mask) for gray, white (punish), + light tunnel, and dark tunnels, but *per trial*, + so each trial contributes its own on/off. + """ + masks: dict[Events, pd.Series] = {} + N = len(df) world_masks = self._get_world_masks(df) + specs = [ + # (which‐world‐test, on‐event, off‐event) + (world_masks.in_gray, Events.gray_on, Events.gray_off), + (world_masks.in_white, Events.punish_on, Events.punish_off), + (world_masks.in_dark, Events.dark_on, Events.dark_off), ] - def _position_event_masks( + for bool_mask, event_on, event_off in specs: + # compute the per‐trial first‐index + trials = df[bool_mask].groupby("trial_count") + masks[event_on] = trials.apply(self._first_index) + # compute the per‐trial last‐index + masks[event_off] = trials.apply(self._last_index) + + + # >>>> trial ends >>>> + gray_on_t = masks[Events.gray_on] + + # for non punished trials, right before gray on is when trial ends, plus + # the last frame of the session + trial_ends_t = gray_on_t.copy() + trial_ends_t.iloc[:-1] = (gray_on_t[1:] - 1).to_numpy() + trial_ends_t.iloc[-1] = df.index[-1] + + # trial ends right before punishment starts + punish_on_t = masks[Events.punish_on] + trial_ends_t.loc[punish_on_t.index] = punish_on_t - 1 + + masks[Events.trial_end] = trial_ends_t + # <<<< trial ends <<<< + + # >>> handle light on separately >>> + # build a run id that increments whenever in_light flips + run_id = world_masks.in_light.ne( + world_masks.in_light.shift(fill_value=False) + ).cumsum() + + # restrict to just the True‐runs + light_runs = df[world_masks.in_light].copy() + light_runs['run_id'] = run_id[world_masks.in_light] + + # for each (trial, run_id) get the on‐times and off‐times + firsts = ( + light_runs + .groupby(["trial_count", "run_id"]) + .apply(self._first_index) + ) + light_ons = firsts.droplevel("run_id") + lasts = ( + light_runs + .groupby(["trial_count", "run_id"]) + .apply(self._last_index) + ) + light_offs = lasts.droplevel("run_id") + + masks[Events.light_on] = light_ons + masks[Events.light_off] = light_offs + # <<< handle light on separately <<< + + return masks + + self, df: pd.DataFrame ) -> List[Tuple[Events, np.ndarray]]: From 13345c9f5918dced874e3a09cfaaccf13f06fd08 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:36:28 +0100 Subject: [PATCH 443/658] add positional events --- pixels/behaviours/virtual_reality.py | 132 +++++++++++++++++++++++---- 1 file changed, 113 insertions(+), 19 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 530c13c..aa11c69 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -654,27 +654,121 @@ def _world_event_indices( return masks + # ------------------------------------------------------------------------- + # Build positional event masks (landmarks, pre_dark_end, reward zone) + # ------------------------------------------------------------------------- + + def _position_event_indices( self, + session, df: pd.DataFrame - ) -> List[Tuple[Events, np.ndarray]]: - # pre‐dark end: when position ≥ (start_pos + pre_dark_len) - start_pos = df[df.world_index==Worlds.TUNNEL] \ - 
.groupby("trial_count")["position_in_tunnel"] \ - .first() - end_vals = start_pos + self.pre_dark_len - # mask per‐trial then merge - pre_dark_mask = np.zeros(len(df), bool) - for trial, thresh in end_vals.items(): - idx = df.trial_count==trial - pre_dark_mask[idx] |= (df.position_in_tunnel[idx] >= thresh) - - # reward zone: any frame at or beyond rz_start - rz_mask = df.position_in_tunnel >= self.rz_start - - return [ - (Events.pre_dark_end, pre_dark_mask), - (Events.reward_zone, rz_mask), - ] + ) -> dict[Events, pd.Series]: + masks: dict[Events, pd.Series] = {} + + in_tunnel = self._get_world_masks(df).in_tunnel + + def _first_post_mark(group_df, check_marks): + if isinstance(check_marks, pd.Series): + group_id = group_df.name + mark = check_marks.loc[group_id] + elif isinstance(check_marks, int): + mark = check_marks + + # mask and pick the first index + mask = (group_df["position_in_tunnel"] >= mark) + if not mask.any(): + return None + return group_df.index[mask].min() + + # >>> distance travelled before dark onset per trial >>> + # NOTE: this applies to light trials too to keep data symetrical + # AL remove pre_dark_len + 10cm in all his data + in_tunnel_trials = df[in_tunnel].groupby("trial_count") + # get start of each trial + trial_starts = in_tunnel_trials.apply(self._first_index) + masks[Events.trial_start] = trial_starts + + # get starting positions of all trials + start_pos = in_tunnel_trials["position_in_tunnel"].first() + # plus pre dark + end_pre_dark = start_pos + session.pre_dark_len + + pre_dark_end_t = in_tunnel_trials.apply( + lambda df: _first_post_mark(df, end_pre_dark) + ).dropna().astype(int) + + masks[Events.pre_dark_end] = pre_dark_end_t + # >>> distance travelled before dark onset per trial >>> + + # >>> landmark 0 black wall >>> + black_off = session.landmarks[0] + + starts_before_black = (start_pos < black_off) + early_ids = start_pos[starts_before_black].index + early_trials = df[ + in_tunnel & df.trial_count.isin(early_ids) + ].groupby("trial_count") + + # first frame of black wall + landmark0_on = early_trials.apply( + self._first_index + ) + masks[Events.landmark0_on] = landmark0_on + + # last frame of black wall + landmark0_off = early_trials.apply( + lambda df: _first_post_mark(df, black_off) + ) + masks[Events.landmark0_off] = landmark0_off + # <<< landmark 0 black wall <<< + + # >>> landmarks 1 to 5 >>> + landmarks = session.landmarks[1:] + + for l, landmark in enumerate(landmarks): + if l % 2 != 0: + continue + + landmark_idx = l // 2 + 1 + + # even idx on, odd idx off + landmark_on = landmark + landmark_off = landmarks[l + 1] + + in_landmark = ( + (df.position_in_tunnel >= landmark_on) & + (df.position_in_tunnel <= landmark_off) + ) + + landmark_on = df[in_landmark].groupby("trial_count").apply( + self._first_index + ) + landmark_off = df[in_landmark].groupby("trial_count").apply( + self._last_index + ) + + masks[getattr(Events, f"landmark{landmark_idx}_on")] = landmark_on + masks[getattr(Events, f"landmark{landmark_idx}_off")] = landmark_off + # <<< landmarks 1 to 5 <<< + + # >>> reward zone >>> + in_zone = ( + df.position_in_tunnel >= session.reward_zone_start + ) & ( + df.position_in_tunnel <= session.reward_zone_end + ) + in_zone_trials = df[in_zone].groupby("trial_count") + + # first frame in reward zone + zone_on_t = in_zone_trials.apply(self._first_index) + masks[Events.reward_zone_on] = zone_on_t + + # last frame in reward zone + zone_off_t = in_zone_trials.apply(self._last_index) + masks[Events.reward_zone_off] = zone_off_t + # <<< reward 
zone <<< + + return masks # ------------------------------------------------------------------------- From 4394964fa64fecc11c341690c35cdac17cfafb8d Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:36:53 +0100 Subject: [PATCH 444/658] add index helper func --- pixels/behaviours/virtual_reality.py | 36 ++++++---------------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index aa11c69..1ffb2ca 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -772,39 +772,17 @@ def _first_post_mark(group_df, check_marks): # ------------------------------------------------------------------------- - # Utilities for detecting run‐start and run‐end of a boolean mask + # Run‐start / run‐end utilities for boolean masks # ------------------------------------------------------------------------- - @staticmethod - def _first_in_run(mask: np.ndarray) -> np.ndarray: - """ - True exactly at the first True of each contiguous run in `mask`. - """ - idx = np.flatnonzero(mask) - if idx.size == 0: - return np.zeros_like(mask) - # a run starts where current True differs from previous True - starts = np.concatenate([[True], - mask[idx[1:]] != mask[idx[:-1]]]) - first_idx = idx[starts] - out = np.zeros_like(mask) - out[first_idx] = True - return out + def _first_index(self, group: pd.DataFrame) -> int: + return group.index.min() - @staticmethod - def _last_in_run(mask: np.ndarray) -> np.ndarray: - """ - True exactly at the last True of each contiguous run in `mask`. - """ - idx = np.flatnonzero(mask) - if idx.size == 0: - return np.zeros_like(mask) - ends = np.concatenate([mask[idx[:-1]] != mask[idx[1:]], [True]]) - last_idx = idx[ends] - out = np.zeros_like(mask) - out[last_idx] = True - return out + def _last_index(self, group: pd.DataFrame) -> int: + return group.index.max() + def _get_index(self, df: pd.DataFrame, index) -> int: + return df.index.get_indexer(index) # ------------------------------------------------------------------------- # Outcome‐mapping machinery From 7152cc9ccd6df4aa4794fc0ba104124290d99c54 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:37:11 +0100 Subject: [PATCH 445/658] map outcome --- pixels/behaviours/virtual_reality.py | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 1ffb2ca..c9e88b4 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -785,27 +785,27 @@ def _get_index(self, df: pd.DataFrame, index) -> int: return df.index.get_indexer(index) # ------------------------------------------------------------------------- - # Outcome‐mapping machinery + # Outcome mapping # ------------------------------------------------------------------------- def _build_outcome_map(self) -> dict: - """ - Returns a dict mapping (Outcomes, Conditions) → ActionLabels. 
- """ m = { - (Outcomes.ABORTED_DARK, Conditions.DARK): ActionLabels.miss_dark, - (Outcomes.ABORTED_LIGHT, Conditions.LIGHT): ActionLabels.miss_light, - (Outcomes.PUNISHED, Conditions.LIGHT): ActionLabels.punished_light, - (Outcomes.PUNISHED, Conditions.DARK): ActionLabels.punished_dark, - (Outcomes.AUTO_LIGHT, Conditions.LIGHT): ActionLabels.auto_light, - (Outcomes.DEFAULT, Conditions.LIGHT): ActionLabels.default_light, - (Outcomes.REINF_LIGHT, Conditions.LIGHT): ActionLabels.reinf_light, - (Outcomes.AUTO_DARK, Conditions.DARK): ActionLabels.auto_dark, - (Outcomes.REINF_DARK, Conditions.DARK): ActionLabels.reinf_dark, - - # triggered must include light vs dark - (Outcomes.TRIGGERED, Conditions.LIGHT): ActionLabels.triggered_light, - (Outcomes.TRIGGERED, Conditions.DARK): ActionLabels.triggered_dark, + (Outcomes.ABORTED_LIGHT, Conditions.LIGHT): TrialTypes.miss_light, + (Outcomes.ABORTED_DARK, Conditions.DARK): TrialTypes.miss_dark, + + (Outcomes.NONE, Conditions.LIGHT): TrialTypes.punished_light, + (Outcomes.NONE, Conditions.DARK): TrialTypes.punished_dark, + + (Outcomes.DEFAULT, Conditions.LIGHT): TrialTypes.default_light, + + (Outcomes.AUTO_LIGHT, Conditions.LIGHT): TrialTypes.auto_light, + (Outcomes.AUTO_DARK, Conditions.DARK): TrialTypes.auto_dark, + + (Outcomes.REINF_LIGHT, Conditions.LIGHT): TrialTypes.reinf_light, + (Outcomes.REINF_DARK, Conditions.DARK): TrialTypes.reinf_dark, + + (Outcomes.TRIGGERED, Conditions.LIGHT): TrialTypes.triggered_light, + (Outcomes.TRIGGERED, Conditions.DARK): TrialTypes.triggered_dark, } return m From 7d213d17f0119fa86e5fd753e1bec05730921317 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:37:34 +0100 Subject: [PATCH 446/658] map outcome flags --- pixels/behaviours/virtual_reality.py | 69 ++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 18 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index c9e88b4..2b944b4 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -811,23 +811,56 @@ def _build_outcome_map(self) -> dict: def _compute_outcome_flag( self, - trial_df: pd.DataFrame, - outcome_map: dict - ) -> ActionLabels: - """ - Given one trial's DataFrame, look at its reward_type & trial_type - and return the matching ActionLabels member. 
- """ - rts = trial_df.reward_type.unique() - if rts.size == 0: - # no reward_type → unfinished or last‐trial abort - return ActionLabels.NONE - - rt = Outcomes(int(rts[0])) - cond = Conditions(int(trial_df.trial_type.iloc[0])) - key = (rt, cond) + trial_id: int, + trial_df: pd.DataFrame, + outcome_map: dict, + session, + ) -> TrialTypes: + + valve_events = {} + + # get non-zero reward types + reward_not_none = (trial_df.reward_type != Outcomes.NONE) + reward_typed = trial_df[reward_not_none] + + # get trial type + trial_type = int(trial_df.trial_type.iloc[0]) + + # get punished + punished = (trial_df.world_index == Worlds.WHITE) + + if (reward_typed.size == 0) & (not np.any(punished)): + # >>>> unfinished trial >>>> + # double check it is the last trial + assert (trial_df.position_in_tunnel.max()\ + < session.tunnel_reset) + logging.info(f"\n> trial {trial_id} is unfinished when session " + "ends, so there is no outcome.") + return TrialTypes.NONE, valve_events + # <<<< unfinished trial <<<< + elif (reward_typed.size == 0) & np.any(punished): + logging.info(f"\n> trial {trial_id} is punished.") + # get reward type zero for punished + reward_type = int(trial_df.reward_type.unique()) + else: + # get non-zero reward type in current trial + reward_type = int(reward_typed.reward_type.unique()) + + if reward_type > Outcomes.NONE: + # >>>> non aborted, valve events >>>> + # if not aborted, map valve open & closed + # map valve open + valve_open_t = reward_typed.index[0] + valve_events[Events.valve_open] = [valve_open_t] + # map valve closed + valve_closed_t = reward_typed.index[-1] + valve_events[Events.valve_closed] = [valve_closed_t] + # <<<< non aborted, valve events <<<< + + # build key for outcome_map + key = (reward_type, trial_type) + try: - return outcome_map[key] + return outcome_map[key], valve_events except KeyError: - raise PixelsError(f"No outcome mapping for {key}") -''' + raise PixelsError(f"No mapping for outcome {key}") From 79fe84cf904ea338bc5c439dfdcfc9125ae88041 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 4 Jun 2025 18:37:52 +0100 Subject: [PATCH 447/658] convert action labels to dictionary before saving --- pixels/stream.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 3146be5..b9ec596 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -395,11 +395,10 @@ def _sync_vr(self, vr): vr_session, synched_vr, ) + labels_dict = action_labels._asdict() np.savez_compressed( action_labels_path, - outcome=action_labels[:, 0], - events=action_labels[:, 1], - timestamps=action_labels[:, 2], + **labels_dict, ) logging.info(f"\n> Action labels saved to: {action_labels_path}.") From af58f60e76f2b2a00275107b46dc8fa76eddca3c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 12:39:25 +0100 Subject: [PATCH 448/658] use arg.name in the cache file name if available --- pixels/decorators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 3dda450..13b175e 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -38,7 +38,8 @@ def wrapper(*args, **kwargs): as_list[i] = name # build a key: method name + all args - key_parts = [method.__name__] + [str(i) for i in as_list] + key_parts = [method.__name__] + [str(i.name) if hasattr(i, "name") + else str(i) for i in as_list] cache_path = self.cache /\ ("_".join(key_parts) + f"_{self.stream_id}.h5") From 29a8192bcbb971fd057cecd91c11184b4de79576 Mon Sep 17 00:00:00 2001 From: amz_office 
Date: Thu, 5 Jun 2025 12:40:07 +0100 Subject: [PATCH 449/658] add trial condition mask --- pixels/behaviours/virtual_reality.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 2b944b4..d726893 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -469,6 +469,7 @@ class LabeledEvents(NamedTuple): outcome: np.ndarray # shape (N,) dtype uint32 events: np.ndarray # shape (N,) dtype uint32 + class WorldMasks(NamedTuple): in_gray: pd.Series in_dark: pd.Series @@ -476,6 +477,12 @@ class WorldMasks(NamedTuple): in_light: pd.Series in_tunnel: pd.Series + +class ConditionMasks(NamedTuple): + light_trials: pd.Series + dark_trials: pd.Series + + class VR(Behaviour): """Behaviour subclass to extract action labels & events from vr_data.""" def _get_world_masks(self, df: pd.DataFrame) -> WorldMasks: @@ -498,6 +505,14 @@ def _get_world_masks(self, df: pd.DataFrame) -> WorldMasks: in_tunnel = in_tunnel, ) + + def _get_condition_masks(self, df: pd.DataFrame) -> ConditionMasks: + return ConditionMasks( + light_trials = (df.trial_type == Conditions.LIGHT), + dark_trials = (df.trial_type == Conditions.DARK), + ) + + def _extract_action_labels( self, session, From 819511a055b048605af85698928d23870497828e Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 12:40:56 +0100 Subject: [PATCH 450/658] get world index based events to fetch dark on events for pre_dark_end --- pixels/behaviours/virtual_reality.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index d726893..697acc9 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -529,12 +529,21 @@ def _extract_action_labels( outcomes_arr = np.zeros(N, dtype=np.uint32) # world index based events - for event, idx in self._world_event_indices(data).items(): + world_index_based_events = self._world_event_indices(data) + for event, idx in world_index_based_events.items(): mask = self._get_index(data, idx.to_numpy()) self._stamp_mask(events_arr, mask, event) + # get dark onset times for pre_dark_len labels + dark_on_t = world_index_based_events[Events.dark_on] + # positional events (pre‐dark end, landmarks, reward‐zone) - for event, idx in self._position_event_indices(session, data).items(): + pos_based_events = self._position_event_indices( + session, + data, + dark_on_t, + ) + for event, idx in pos_based_events.items(): mask = self._get_index(data, idx.to_numpy()) self._stamp_mask(events_arr, mask, event) From bb4ae71dbe00003fd01efe5dd51a5b7fb3eb9345 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 20:46:53 +0100 Subject: [PATCH 451/658] print name of label --- pixels/behaviours/base.py | 2 +- pixels/stream.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index db6bb51..e889c91 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2649,7 +2649,7 @@ def get_positional_data( logging.info( f"\n> Getting positional neural data of {units} units in " - f"<{label}> trials." + f"<{label.name}> trials." ) output[stream_id] = stream.get_positional_data( units=units, # NOTE: put units first! 
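Worth making explicit what the stream.py change in PATCH 447 above relies on: a `NamedTuple`'s `_asdict()` gives a field-name → array mapping, and `np.savez_compressed` stores each keyword argument under its own key, so the labels round-trip by name rather than by column position. A minimal sketch under that assumption (the class and field names below are dummies chosen to mirror the series, not the repository's definitions):

    from typing import NamedTuple
    import numpy as np

    class LabeledFrames(NamedTuple):      # illustrative stand-in for the labels tuple
        outcome: np.ndarray
        events: np.ndarray
        timestamps: np.ndarray

    labels = LabeledFrames(
        outcome=np.zeros(5, dtype=np.uint32),
        events=np.zeros(5, dtype=np.uint32),
        timestamps=np.arange(5),
    )

    np.savez_compressed("action_labels.npz", **labels._asdict())

    loaded = np.load("action_labels.npz")
    outcome = loaded["outcome"]           # fields come back under the same keys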
diff --git a/pixels/stream.py b/pixels/stream.py index b9ec596..726d80b 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -83,7 +83,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): if "trial_rate" in data: logging.info( f"\n> Aligning spike times and spike rate of {units} units to " - f"<{label}> trials." + f"<{label.name}> trials." ) return self._get_aligned_trials( label, event, data=data, units=units, sigma=sigma, From e92cf2398ca76ebef6bf09a8e41712afa2941a97 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 20:47:45 +0100 Subject: [PATCH 452/658] adjust starting locations if trial not aligned to starts --- pixels/pixels_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index aa0c725..58c8646 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1246,12 +1246,12 @@ def _get_vr_positional_neural_data(event, positions, data_type, data): # get constants from vd from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END,\ PRE_DARK_LEN - from pixels.behaviours.virtual_reality import Events # get the starting index for each trial (column) starts = positions.iloc[0, :].astype(int) - # if align to dark_onset, actual starting position is before that - if event == Events.dark_on: + # NOTE: if align to dark_onset or end of pre dark, actual starting position + # is before that + if np.isin(["dark_on", "pre_dark_end"], event.name).any(): starts = starts - PRE_DARK_LEN # create position indices indices = np.arange(0, TUNNEL_RESET+1) From c9a29b9965caf7210f4afb5835d7d1ceacc9547d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 20:48:25 +0100 Subject: [PATCH 453/658] drop nans if in all columns --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 58c8646..88645b4 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1329,7 +1329,7 @@ def _get_vr_positional_neural_data(event, positions, data_type, data): axis=1, level=["unit", "start", "trial"], ascending=[True, False, True], - ) + ).dropna(how="all") return pos_data, occupancy From fbaf37789632fb214f0a0e41768b411a3cae6ba1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 20:49:27 +0100 Subject: [PATCH 454/658] make sure name is before doc --- pixels/behaviours/virtual_reality.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 697acc9..ca41a83 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -583,16 +583,18 @@ def _extract_action_labels( # ------------------------------------------------------------------------- @staticmethod - """Bitwise‐OR `flag` into `storage` at every True in `mask`.""" def _stamp_mask(array: np.ndarray, mask: np.ndarray, flag: IntFlag): + """ + Bitwise‐OR `flag` into `storage` at every True in `mask`. + """ np.bitwise_or.at(array, mask, flag) @staticmethod + def _stamp_rising(array: np.ndarray, signal: np.ndarray, flag: IntFlag): """ Find rising‐edge frames in a 0/1 `signal` array (diff == +1) and stamp `flag` at those indices. 
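Since the `_stamp_mask` / `_stamp_rising` helpers reordered in PATCH 454 above are the core of the flag-based labelling, a compact illustration of the scheme may help: every frame carries a bitmask, `np.bitwise_or.at` stamps a flag into chosen frames in place, and `np.bitwise_and` later recovers the frames that carry that flag regardless of any other bits set on them. A small sketch (the `Events` members below are made up for illustration, not the real enum):

    import numpy as np
    from enum import IntFlag

    class Events(IntFlag):                 # illustrative members only
        trial_start = 1 << 0
        dark_on     = 1 << 1

    events = np.zeros(10, dtype=np.uint32)

    # stamp: OR the flag into the selected frames, in place
    np.bitwise_or.at(events, [2, 7], Events.trial_start)
    np.bitwise_or.at(events, [2], Events.dark_on)

    # query: frames carrying a flag, whatever else is stamped on them
    starts = np.where(np.bitwise_and(events, Events.trial_start))[0]   # -> [2, 7]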
""" - def _stamp_rising(array: np.ndarray, signal: np.ndarray, flag: IntFlag): # extract edges edges = np.flatnonzero(np.diff(signal, prepend=0) == 1) np.bitwise_or.at(array, edges, flag) From f67fe5a75ad693f88e57251b2e38de063243d4a1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 20:50:23 +0100 Subject: [PATCH 455/658] make sure for dark trials, pre dark end is the same frame as dark onset --- pixels/behaviours/virtual_reality.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index ca41a83..e514228 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -687,7 +687,8 @@ def _world_event_indices( def _position_event_indices( self, session, - df: pd.DataFrame + df: pd.DataFrame, + dark_on_t, ) -> dict[Events, pd.Series]: masks: dict[Events, pd.Series] = {} @@ -714,15 +715,30 @@ def _first_post_mark(group_df, check_marks): trial_starts = in_tunnel_trials.apply(self._first_index) masks[Events.trial_start] = trial_starts + # NOTE: dark trials should in theory have EQUAL index pre_dark_end_t + # and dark_on, BUT! after interpolation, some dark trials have their + # dark onset earlier than expected, those are corrected to the first + # expected position, they will have the same index as pre_dark_end_t. + # others will not since their world_index change later than expected. + # SO! to keep it consistent, for dark trials, pre_dark_end will be the + # SAME frame as dark onsets. + # get starting positions of all trials start_pos = in_tunnel_trials["position_in_tunnel"].first() - # plus pre dark - end_pre_dark = start_pos + session.pre_dark_len - pre_dark_end_t = in_tunnel_trials.apply( - lambda df: _first_post_mark(df, end_pre_dark) + # get light trials + lights = self._get_condition_masks(df).light_trials + light_trials = df[in_tunnel & lights].groupby("trial_count") + # get starting positions of light trials plus pre dark length + light_pre_dark_len = light_trials["position_in_tunnel"].first()\ + + session.pre_dark_len + light_pre_dark_end_t = light_trials.apply( + lambda df: _first_post_mark(df, light_pre_dark_len) ).dropna().astype(int) + # concat dark and light trials + pre_dark_end_t = pd.concat([dark_on_t, light_pre_dark_end_t]) + masks[Events.pre_dark_end] = pre_dark_end_t # >>> distance travelled before dark onset per trial >>> From daedfc1a075fb0ff9fd273a52895f2b4e86eece5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 5 Jun 2025 20:58:59 +0100 Subject: [PATCH 456/658] print name of intflag, not value --- pixels/stream.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 726d80b..0a0ede1 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -109,8 +109,8 @@ def _get_aligned_trials( # get action and event label file outcomes = action_labels["outcome"] events = action_labels["events"] - # get timestamps index of behaviour in self.BEHAVIOUR_SAMPLE_RATE hz, to convert - # it to ms, do timestamps*1000/self.BEHAVIOUR_SAMPLE_RATE + # get timestamps index of behaviour in self.BEHAVIOUR_SAMPLE_RATE hz, to + # convert it to ms, do timestamps*1000/self.BEHAVIOUR_SAMPLE_RATE timestamps = action_labels["timestamps"] # select frames of wanted trial type @@ -129,7 +129,7 @@ def _get_aligned_trials( if selected_starts.size == 0: logging.info(f"\n> No trials found with label {label} and event " - f"{event}, output will be empty.") + f"{event.name}, 
output will be empty.") return None # use original trial id as trial index @@ -425,7 +425,7 @@ def get_binned_trials( time_bin=None, pos_bin=None ): # define output path for binned spike rate - output_path = self.cache/ f"{self.session.name}_{label}_{units}_"\ + output_path = self.cache/ f"{self.session.name}_{label.name}_{units}_"\ f"{time_bin}_{pos_bin}cm_{self.stream_id}.npz" binned = self._bin_aligned_trials( label=label, @@ -456,12 +456,12 @@ def _bin_aligned_trials( ) if trials is None: - logging.info(f"\n> No trials found with label {label} and event " - f"{event}, output will be empty.") + logging.info(f"\n> No trials found with label {label.name} and " + f"event {event.name}, output will be empty.") return None logging.info( - f"\n> Binning <{label}> trials from {self.stream_id} " + f"\n> Binning <{label.name}> trials from {self.stream_id} " f"in {units}." ) From 6f34e62b0ec200b80152048ed502fb04a976467a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:22:37 +0100 Subject: [PATCH 457/658] pass vr session object, not vr object to synchronisation --- pixels/behaviours/base.py | 6 +++--- pixels/experiment.py | 9 +++++++-- pixels/stream.py | 11 +++++------ 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e889c91..50b61d0 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2797,13 +2797,13 @@ def save_spike_chance(self, spiked, sigma, sample_rate): return None - def sync_vr(self, vr): + def sync_vr(self, vr_session): """ Synchronise each pixels stream with virtual reality data. params === - vr: class, virtual reality object. + vr: class, virtual reality session object. """ streams = self.files["pixels"] @@ -2818,7 +2818,7 @@ def sync_vr(self, vr): logging.info( f"\n> Synchonising pixels data with vr." ) - stream.sync_vr(vr) + stream.sync_vr(vr_session) return None diff --git a/pixels/experiment.py b/pixels/experiment.py index 11c3173..8ff299d 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -589,10 +589,15 @@ def get_binned_trials(self, *args, units=None, **kwargs): def sync_vr(self, vr): """ - Synchronise virtual reality data with pixels streams. + Synchronise virtual reality data of a mouse (or mice) with pixels + streams. 
""" trials = {} for i, session in enumerate(self.sessions): - session.sync_vr(vr) + # vr is a vision-in-darkness mouse object + vr_session = vr.sessions[i] + assert session.name.split("_")[0] in vr_session.name + + session.sync_vr(vr_session) return None diff --git a/pixels/stream.py b/pixels/stream.py index 0a0ede1..8fe2d62 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -305,7 +305,7 @@ def get_spike_times(self, units): return spike_times - def sync_vr(self, vr): + def sync_vr(self, vr_session): # get action labels & synched vr path action_labels = self.session.get_action_labels()[self.stream_num] synched_vr_path = self.session.find_file( @@ -315,20 +315,19 @@ def sync_vr(self, vr): logging.info(f"\n> {self.stream_id} from {self.session.name} is " "already synched with vr.") else: - self._sync_vr(vr) + self._sync_vr(vr_session) return None - def _sync_vr(self, vr): + def _sync_vr(self, vr_session): # get spike data spike_data = self.session.find_file( name=self.files["ap_raw"][self.stream_num], copy=True, ) - # get vr data - vr_session = vr.sessions[0] + # get synchronised vr path synched_vr_path = vr_session.cache_dir + "synched/" +\ vr_session.name + "_vr_synched.h5" @@ -373,7 +372,7 @@ def _sync_vr(self, vr): # convert value into their index to calculate all timestamps pixels_idx = np.arange(pixels_syncs.shape[0]) - synched_vr = vr.sync_streams( + synched_vr = vr_session.sync_streams( self.BEHAVIOUR_SAMPLE_RATE, pixels_vr_edges, pixels_idx, From 6b39c1b125211a938e743b763a86ffd23f8d9663 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:23:15 +0100 Subject: [PATCH 458/658] print name --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 50b61d0..44875c7 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2816,7 +2816,7 @@ def sync_vr(self, vr_session): ) logging.info( - f"\n> Synchonising pixels data with vr." + f"\n> Synchonising {self.name} {stream_id} pixels data with vr." 
) stream.sync_vr(vr_session) From 848d71d5fca2a28b3475be8944809ba034787af7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:25:43 +0100 Subject: [PATCH 459/658] map actual starting positions if trials not aligned to starts --- pixels/stream.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 8fe2d62..9a1fe46 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -133,7 +133,31 @@ def _get_aligned_trials( return None # use original trial id as trial index - trial_ids = synched_vr.iloc[selected_starts].trial_count.unique() + trial_ids = pd.Index( + synched_vr.iloc[selected_starts].trial_count.unique() + ) + + # map actual starting locations + if not "trial_start" in event.name: + from pixels.behaviours.virtual_reality import Events + all_start_idx = np.where( + np.bitwise_and(events, Events.trial_start) + )[0] + start_idx = trials[np.where( + np.isin(trials, all_start_idx) + )[0]] + else: + start_idx = selected_starts.copy() + + start_pos = synched_vr.position_in_tunnel.iloc[ + start_idx + ].values.astype(int) + + # create multiindex with starts + cols_with_starts = pd.MultiIndex.from_arrays( + [start_pos, trial_ids], + names=("start", "trial"), + ) # pad ends with 1 second extra to remove edge effects from # convolution From fbbd95a1608515d1c9652f0645e27b1f4c7b030a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:26:08 +0100 Subject: [PATCH 460/658] add start as top level on positions --- pixels/pixels_utils.py | 9 ++++++++- pixels/stream.py | 2 ++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 88645b4..d727520 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1316,7 +1316,14 @@ def _get_vr_positional_neural_data(event, positions, data_type, data): trial_level = pos_data.columns.get_level_values("trial") unit_level = pos_data.columns.get_level_values("unit") # map start level - start_level = trial_level.map(starts) + starts = positions.columns.get_level_values("start").values + start_series = pd.Series( + data=starts, + index=trial_ids, + name="start", + ) + start_level = trial_level.map(start_series) + # define new columns new_cols = pd.MultiIndex.from_arrays( [start_level, unit_level, trial_level], diff --git a/pixels/stream.py b/pixels/stream.py index 9a1fe46..b82089f 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -248,6 +248,8 @@ def _get_aligned_trials( level="trial", return_format="dataframe", ) + positions.columns = cols_with_starts + positions = positions.sort_index(axis=1, ascending=[False, True]) # get trials vertically stacked spiked stacked_spiked = pd.concat( From a794f1658db3708eee42cabd5aedd6078ee9e03d Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:27:54 +0100 Subject: [PATCH 461/658] no need to calculate starts cuz it's already a level in position --- pixels/pixels_utils.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d727520..680e825 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1244,15 +1244,8 @@ def _get_vr_positional_neural_data(event, positions, data_type, data): logging.info(f"\n> Getting positional {data_type}...") # get constants from vd - from vision_in_darkness.constants import TUNNEL_RESET, ZONE_END,\ - PRE_DARK_LEN - - # get the starting index for each trial (column) - starts = positions.iloc[0, :].astype(int) - # NOTE: 
if align to dark_onset or end of pre dark, actual starting position - # is before that - if np.isin(["dark_on", "pre_dark_end"], event.name).any(): - starts = starts - PRE_DARK_LEN + from vision_in_darkness.constants import TUNNEL_RESET + # create position indices indices = np.arange(0, TUNNEL_RESET+1) # create occupancy array for trials From 05b9ba40dbc0a35fe63c2c0fa037ecb32c86ecfe Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:28:27 +0100 Subject: [PATCH 462/658] loop over trial ids and use xs to get trial positions --- pixels/pixels_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 680e825..17ae33c 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1255,9 +1255,10 @@ def _get_vr_positional_neural_data(event, positions, data_type, data): ) pos_data = {} - for t, trial in enumerate(positions): + trial_ids = positions.columns.get_level_values("trial") + for t, trial in enumerate(trial_ids): # get trial position - trial_pos = positions[trial].dropna() + trial_pos = positions.xs(trial, level="trial", axis=1).dropna() # floor position and set to int trial_pos = trial_pos.apply(lambda x: np.floor(x)).astype(int) From e8cf96b0ed8db1545b7c93e5ff82b2aedeff65a0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:28:52 +0100 Subject: [PATCH 463/658] formatting --- pixels/pixels_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 17ae33c..765ad98 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1288,7 +1288,11 @@ def _get_vr_positional_neural_data(event, positions, data_type, data): elif data_type == "spiked": # group values by position and get sum data how = "sum" - grouped_data = math_utils.group_and_aggregate(trial_data, "position", how) + grouped_data = math_utils.group_and_aggregate( + trial_data, + "position", + how, + ) # reindex into full tunnel length pos_data[trial] = grouped_data.reindex(indices) From e76f6237f65084b79b61c398c2f06086ad3985f2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:29:08 +0100 Subject: [PATCH 464/658] use double quote --- pixels/stream.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index b82089f..214f3c0 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -404,10 +404,9 @@ def _sync_vr(self, vr_session): pixels_idx, )[vr_session.name] - # save to pixels processed dir file_utils.write_hdf5( self.processed /\ - self.behaviour_files['vr_synched'][self.stream_num], + self.behaviour_files["vr_synched"][self.stream_num], synched_vr, ) From 92cc3d7861f2fa32ef581b528c3c503086808c42 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 6 Jun 2025 13:30:07 +0100 Subject: [PATCH 465/658] no need to index into dict cuz vr_session is a session object --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 214f3c0..fc8da1f 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -402,7 +402,7 @@ def _sync_vr(self, vr_session): self.BEHAVIOUR_SAMPLE_RATE, pixels_vr_edges, pixels_idx, - )[vr_session.name] + ) file_utils.write_hdf5( self.processed /\ From 92cd575f8d3b5f7e2dd6408677471a2dfdbcd897 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 16:47:10 +0100 Subject: [PATCH 466/658] print label name rather than value --- pixels/behaviours/base.py | 6 +++--- pixels/stream.py | 2 +- 
2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 44875c7..b624b74 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1964,7 +1964,7 @@ def align_trials( if "trial" in data: logging.info( f"\n> Aligning {self.name} {data} of {units} units to label " - f"<{label}> trials." + f"<{label.name}> trials." ) return self._get_aligned_trials( label, event, data=data, units=units, sigma=sigma, @@ -2691,7 +2691,7 @@ def get_binned_trials( ) logging.info( - f"\n> Getting binned <{label}> trials from {stream_id} " + f"\n> Getting binned <{label.name}> trials from {stream_id} " f"in {units}." ) binned[stream_id] = stream.get_binned_trials( @@ -2845,7 +2845,7 @@ def get_spatial_psd( logging.info( f"\n> Getting spatial PSD of {units} units in " - f"<{label}> trials." + f"<{label.name}> trials." ) output[stream_id] = stream.get_spatial_psd( units=units, # NOTE: put units first! diff --git a/pixels/stream.py b/pixels/stream.py index fc8da1f..7017fed 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -449,7 +449,7 @@ def get_binned_trials( time_bin=None, pos_bin=None ): # define output path for binned spike rate - output_path = self.cache/ f"{self.session.name}_{label.name}_{units}_"\ + output_path = self.cache/ f"{self.session.name}_{units}_{label.name}_"\ f"{time_bin}_{pos_bin}cm_{self.stream_id}.npz" binned = self._bin_aligned_trials( label=label, From fc3ff9c14f6fb419b1752dcb93d41ced4389b5e0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 16:47:33 +0100 Subject: [PATCH 467/658] take in binning args to allow psd for binned data --- pixels/behaviours/base.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b624b74..75ef410 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2825,6 +2825,7 @@ def sync_vr(self, vr_session): def get_spatial_psd( self, label, event, end_event=None, sigma=None, units=None, + crop_from=None, use_binned=False, time_bin=None, pos_bin=None, ): """ Get spatial power spectral density of selected units. @@ -2848,11 +2849,15 @@ def get_spatial_psd( f"<{label.name}> trials." ) output[stream_id] = stream.get_spatial_psd( - units=units, # NOTE: put units first! + units=units, label=label, event=event, sigma=sigma, - end_event=end_event, # NOTE: put units last! 
+ end_event=end_event, + crop_from=crop_from, + use_binned=use_binned, + time_bin=time_bin, + pos_bin=pos_bin, ) return output From 9611ab499a4a13dd25b9c2d714ea727f929dc0f5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:43:44 +0100 Subject: [PATCH 468/658] make sure the name is not the same --- pixels/behaviours/virtual_reality.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index e514228..75ba766 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -774,12 +774,12 @@ def _first_post_mark(group_df, check_marks): landmark_idx = l // 2 + 1 # even idx on, odd idx off - landmark_on = landmark - landmark_off = landmarks[l + 1] + on_landmark = landmark + off_landmark = landmarks[l + 1] in_landmark = ( - (df.position_in_tunnel >= landmark_on) & - (df.position_in_tunnel <= landmark_off) + (df.position_in_tunnel >= on_landmark) & + (df.position_in_tunnel <= off_landmark) ) landmark_on = df[in_landmark].groupby("trial_count").apply( From 795d7240a7027566930e9bd017267e82b2cf856a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:44:30 +0100 Subject: [PATCH 469/658] potentially checking for multiple crossing of threshold --- pixels/behaviours/virtual_reality.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 75ba766..dc6eb27 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -738,7 +738,6 @@ def _first_post_mark(group_df, check_marks): # concat dark and light trials pre_dark_end_t = pd.concat([dark_on_t, light_pre_dark_end_t]) - masks[Events.pre_dark_end] = pre_dark_end_t # >>> distance travelled before dark onset per trial >>> @@ -818,8 +817,20 @@ def _first_post_mark(group_df, check_marks): # ------------------------------------------------------------------------- def _first_index(self, group: pd.DataFrame) -> int: + idx = group.index + early_idx = idx[:len(idx) // 2] + + # double check the last index is the first time reaching that point + idx_discontinued = (np.diff(early_idx) > 1) + #if np.any(idx_discontinued): + # print("discontinued") + # assert 0 + # last_disc = np.where(idx_discontinued)[0][-1] + # return group.iloc[last_disc:].index.min() + #else: return group.index.min() + def _last_index(self, group: pd.DataFrame) -> int: return group.index.max() From 912626361105c5aaedb4f6c88b014dd12bdc49e4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:45:21 +0100 Subject: [PATCH 470/658] take last 1/4 of data check for discontinunity --- pixels/behaviours/virtual_reality.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index dc6eb27..1c12dab 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -832,7 +832,21 @@ def _first_index(self, group: pd.DataFrame) -> int: def _last_index(self, group: pd.DataFrame) -> int: - return group.index.max() + # only check the second half in case the discontinuity happened at the + # beginning + idx = group.index + late_idx = idx[-len(idx)//4:] + + # double check the last index is the first time reaching that point + idx_discontinued = (np.diff(late_idx) > 1) + if np.any(idx_discontinued): + logging.warning("\n> index discontinued.") + print(group) + 
first_disc = np.where(idx_discontinued)[0][0] + return late_idx[first_disc] + else: + return idx.max() + def _get_index(self, df: pd.DataFrame, index) -> int: return df.index.get_indexer(index) From e25d3b92c56aebf03c361c7196ecb759d480fc5e Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:46:27 +0100 Subject: [PATCH 471/658] move behaviour dict up so that vr_synched & action label can be appended along with recs --- pixels/ioutils.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 987cb07..b152fc9 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -76,6 +76,14 @@ def get_data_files(data_dir, session_name): if not ap_meta: raise PixelsError(f"{session_name}: could not find raw AP metadata file.") + pupil_raw = sorted(glob.glob(f"{data_dir}/behaviour/pupil_cam/*.avi*")) + + behaviour = { + "vr_synched": [], + "action_labels": [], + "pupil_raw": pupil_raw, + } + pixels = {} for r, rec in enumerate(ap_raw): stream_id = rec[-12:-4] @@ -98,6 +106,13 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["ap_raw"].append(base_name) pixels[stream_id]["ap_meta"].append(original_name(ap_meta[r])) + behaviour["vr_synched"].append(base_name.with_name( + f"{session_name}_{probe_id}_vr_synched.h5" + )) + behaviour["action_labels"].append(base_name.with_name( + f"action_labels_{probe_id}.npz" + )) + # >>> spikeinterface cache >>> # extracted & motion corrected ap stream, 300Hz+ pixels[stream_id]["ap_motion_corrected"] = base_name.with_name( @@ -175,19 +190,6 @@ def get_data_files(data_dir, session_name): # f"spike_rate_{stream_id}.h5" #) - pupil_raw = sorted(glob.glob(f"{data_dir}/behaviour/pupil_cam/*.avi*")) - - behaviour = { - "vr_synched": [], - "action_labels": [], - "pupil_raw": pupil_raw, - } - - behaviour["vr_synched"].append(base_name.with_name( - f"{session_name}_vr_synched.h5" - )) - behaviour["action_labels"].append(base_name.with_name(f"action_labels.npz")) - if pupil_raw: behaviour["pupil_processed"] = [] behaviour["motion_index"] = [] From d77dccaabeb7f528ec692b5bf01204100c3cbbcc Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:47:19 +0100 Subject: [PATCH 472/658] remove number, keep unit --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 765ad98..1bf48a2 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1124,7 +1124,7 @@ def bin_vr_trial(data, positions, sample_rate, time_bin, pos_bin, # set position index too positions.index = data.index - # resample to 100ms bin, and get position mean + # resample to ms bin, and get position mean mean_pos = positions.resample(time_bin).mean() if bin_method == "sum": From 52756aece6bd288ec7248ccd194d1a2b1446fd7b Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:47:40 +0100 Subject: [PATCH 473/658] remove redundant arg --- pixels/pixels_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1bf48a2..1ea2443 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1186,7 +1186,7 @@ def correct_group_id(rec): return group_ids -def get_vr_positional_data(event, trial_data): +def get_vr_positional_data(trial_data): """ Get positional firing rate and spike count for VR behaviour. 
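The guard added to `_last_index` in PATCH 470 above rests on a simple invariant: within one trial's group the row labels should be consecutive, so a jump greater than 1 in `np.diff` of the tail indices means frames from a later re-entry were grouped in, and the last label before the gap is the real trial end. A stripped-down version of that check (toy labels, integer row index assumed):

    import numpy as np

    idx = np.array([40, 41, 42, 43, 44, 45, 46, 90])   # toy row labels, one trial
    late_idx = idx[-len(idx) // 4:]                     # inspect only the tail

    jumps = np.diff(late_idx) > 1
    if jumps.any():
        trial_end = late_idx[np.where(jumps)[0][0]]     # last label before the gap
    else:
        trial_end = idx.max()
    # here trial_end == 46, not 90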
@@ -1215,7 +1215,7 @@ def get_vr_positional_data(event, trial_data): return {"pos_fr": pos_fr, "pos_fc": pos_fc, "occupancy": occupancy} -def _get_vr_positional_neural_data(event, positions, data_type, data): +def _get_vr_positional_neural_data(positions, data_type, data): """ Get positional neural data for VR behaviour. From 8803aa09295118f9a0b85e607c5184255560198b Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:48:32 +0100 Subject: [PATCH 474/658] add notes; get occupancy with spike count --- pixels/pixels_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1ea2443..75b27c2 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1199,20 +1199,20 @@ def get_vr_positional_data(trial_data): dict, positional firing rate, positional spike count, positional occupancy, data in 1cm resolution. """ - pos_fr, occupancy = _get_vr_positional_neural_data( - event=event, - positions=trial_data["positions"], - data_type="spike_rate", - data=trial_data["fr"], - ) - pos_fc, _ = _get_vr_positional_neural_data( - event=event, + # NOTE: take occupancy from spike count since in we might need to + # interpolate fr for binned data + pos_fc, occupancy = _get_vr_positional_neural_data( positions=trial_data["positions"], data_type="spiked", data=trial_data["spiked"], ) + pos_fr, _ = _get_vr_positional_neural_data( + positions=trial_data["positions"], + data_type="spike_rate", + data=trial_data["fr"], + ) - return {"pos_fr": pos_fr, "pos_fc": pos_fc, "occupancy": occupancy} + return {"pos_fc": pos_fc, "pos_fr": pos_fr, "occupancy": occupancy} def _get_vr_positional_neural_data(positions, data_type, data): From b39f45a894cb4e4b226543f1d14b8ed729853303 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:49:59 +0100 Subject: [PATCH 475/658] allows data to be binned too --- pixels/pixels_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 75b27c2..38c6fd0 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1242,6 +1242,19 @@ def _get_vr_positional_neural_data(positions, data_type, data): shape: position x trial """ logging.info(f"\n> Getting positional {data_type}...") + if "bin" in positions.index.name: + logging.info(f"\n> Getting binned positional {data_type}...") + # create position indices for binned data + indices_range = [ + positions.min().min(), + positions.max().max()+1, + ] + else: + logging.info(f"\n> Getting positional {data_type}...") + # get constants from vd + from vision_in_darkness.constants import TUNNEL_RESET + # create position indices + indices_range = [0, TUNNEL_RESET+1] # get constants from vd from vision_in_darkness.constants import TUNNEL_RESET From 18874f5bb36a982c120badef13be41e930de3d4d Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:50:56 +0100 Subject: [PATCH 476/658] initiate occupancy as df, not np array --- pixels/pixels_utils.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 38c6fd0..c127e57 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1256,15 +1256,16 @@ def _get_vr_positional_neural_data(positions, data_type, data): # create position indices indices_range = [0, TUNNEL_RESET+1] - # get constants from vd - from vision_in_darkness.constants import TUNNEL_RESET + # get trial ids + trial_ids = 
positions.columns.get_level_values("trial") # create position indices - indices = np.arange(0, TUNNEL_RESET+1) + indices = np.arange(*indices_range).astype(int) # create occupancy array for trials - occupancy = np.full( - (TUNNEL_RESET+1, positions.shape[1]), - np.nan, + occupancy = pd.DataFrame( + data=np.full((len(indices), positions.shape[1]), np.nan), + index=indices, + columns=trial_ids, ) pos_data = {} @@ -1311,16 +1312,10 @@ def _get_vr_positional_neural_data(positions, data_type, data): pos_data[trial] = grouped_data.reindex(indices) # get trial occupancy pos_count = trial_data.groupby("position").size() - occupancy[pos_count.index.values, t] = pos_count.values + occupancy.loc[pos_count.index.values, trial] = pos_count.values # concatenate dfs pos_data = pd.concat(pos_data, axis=1, names=["trial", "unit"]) - # convert to df - occupancy = pd.DataFrame( - data=occupancy, - index=indices, - columns=positions.columns, - ) # add another level of starting position # group trials by their starting index From 167bb77ef5ea5138adf17fd5fa29bf8f78f9e866 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:51:54 +0100 Subject: [PATCH 477/658] only floor if position values are float --- pixels/pixels_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c127e57..9d75b3f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1241,7 +1241,8 @@ def _get_vr_positional_neural_data(positions, data_type, data): occupancy: pandas df, count of each position. shape: position x trial """ - logging.info(f"\n> Getting positional {data_type}...") + from pandas.api.types import is_integer_dtype + if "bin" in positions.index.name: logging.info(f"\n> Getting binned positional {data_type}...") # create position indices for binned data @@ -1269,16 +1270,16 @@ def _get_vr_positional_neural_data(positions, data_type, data): ) pos_data = {} - trial_ids = positions.columns.get_level_values("trial") for t, trial in enumerate(trial_ids): # get trial position trial_pos = positions.xs(trial, level="trial", axis=1).dropna() - # floor position and set to int - trial_pos = trial_pos.apply(lambda x: np.floor(x)).astype(int) - - # exclude positions after tunnel reset - trial_pos = trial_pos[trial_pos <= TUNNEL_RESET] + # convert to int if float + if not trial_pos.dtypes.map(is_integer_dtype).all(): + # floor position and set to int + trial_pos = trial_pos.apply(lambda x: np.floor(x)).astype(int) + # exclude positions after tunnel reset + trial_pos = trial_pos[trial_pos <= indices[-1]] # get firing rates for current trial of all units trial_data = data.xs( From 659a8f38495aa3b95bb5d91960a5db94f0ee2a5c Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:52:25 +0100 Subject: [PATCH 478/658] interpolate if binned data has missing values --- pixels/pixels_utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9d75b3f..d4f9393 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1310,7 +1310,26 @@ def _get_vr_positional_neural_data(positions, data_type, data): ) # reindex into full tunnel length - pos_data[trial] = grouped_data.reindex(indices) + reidxed = grouped_data.reindex(indices) + + # check for missing values in binned data + if ("bin" in positions.index.name) and (data_type == "spike_rate"): + # remove alll nan before data actually starts + start_idx = grouped_data.index[0] + 
chunk_data = reidxed.loc[start_idx:, :] + nan_check = chunk_data.isna().any().any() + if nan_check: + # interpolate missing fr + logging.info(f"\n> trial {trial} has missing values, " + "do linear interpolation.") + reidxed.loc[start_idx:, :] = chunk_data.interpolate( + method="linear", + axis=0, + ) + + # save to dict + pos_data[trial] = reidxed + # get trial occupancy pos_count = trial_data.groupby("position").size() occupancy.loc[pos_count.index.values, trial] = pos_count.values From b707ef267b9c860baf24867a66fd0fee11333647 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:52:47 +0100 Subject: [PATCH 479/658] drop nan in occupancy too to save space --- pixels/pixels_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d4f9393..ddca528 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1364,6 +1364,8 @@ def _get_vr_positional_neural_data(positions, data_type, data): ascending=[True, False, True], ).dropna(how="all") + occupancy = occupancy.dropna(how="all") + return pos_data, occupancy From bee840a166f5433117e5b80025ef1f3aeb3bf538 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:53:14 +0100 Subject: [PATCH 480/658] add note --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 7017fed..dcfbd28 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -476,7 +476,7 @@ def _bin_aligned_trials( label=label, event=event, sigma=sigma, - end_event=end_event, + end_event=end_event, # NOTE: ALWAYS the last arg ) if trials is None: From 3677cda86e0a8dee660cf6ba303ba399d6e567bc Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:53:35 +0100 Subject: [PATCH 481/658] get trial ids then loop cuz now cols r multilevel index --- pixels/stream.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index dcfbd28..23f802b 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -503,10 +503,12 @@ def _bin_aligned_trials( bin_arr = {} binned_count = {} binned_fr = {} - for trial in positions.columns.unique(): + + trial_ids = positions.columns.get_level_values("trial").unique() + for trial in trial_ids: counts = spiked.xs(trial, level="trial", axis=1).dropna() rates = fr.xs(trial, level="trial", axis=1).dropna() - trial_pos = positions[trial].dropna() + trial_pos = positions.xs(trial, level="trial", axis=1).dropna() # get bin spike count binned_count[trial] = xut.bin_vr_trial( From 2afc690f14a9f56f47d595b4e25df8d1aea7ff2d Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:54:30 +0100 Subject: [PATCH 482/658] add start level to binned data too --- pixels/stream.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 23f802b..1c48ced 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -544,13 +544,28 @@ def _bin_aligned_trials( logging.info(f"\n> Output saved at {output_path}.") # extract binned data in df format - bin_fc, bin_pos = self._extract_binned_data(binned_count) - bin_fr, _ = self._extract_binned_data(binned_fr) + bin_fc, bin_pos = self._extract_binned_data( + binned_count, + positions.columns, + ) + bin_fr, _ = self._extract_binned_data( + binned_fr, + positions.columns, + ) + + # convert it to binned positional data + pos_data = xut.get_vr_positional_data( + { + "positions": bin_pos.bin_pos, + "fr": bin_fr, + "spiked": bin_fc, + 
}, + ) - return {"bin_fc": bin_fc, "bin_fr": bin_fr, "bin_pos": bin_pos} + return pos_data - def _extract_binned_data(self, binned_data): + def _extract_binned_data(self, binned_data, pos_cols): """ """ df = ioutils.reindex_by_longest( @@ -570,6 +585,19 @@ def _extract_binned_data(self, binned_data): ) pos.columns.names = ["pos_type", "trial"] + # convert columns to df + pos_col_df = pos.columns.to_frame(index=False) + start_trial_df = pos_cols.to_frame(index=False) + # merge columns + merged_cols = pd.merge(pos_col_df, start_trial_df, on="trial") + + # create new columns + new_cols = pd.MultiIndex.from_frame( + merged_cols[["pos_type", "start", "trial"]], + names=["pos_type", "start", "trial"], + ) + pos.columns = new_cols + return data, pos From 1a4328f550ad7578d6372809c21708b51ab94525 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:54:59 +0100 Subject: [PATCH 483/658] add notes and remove redundant arg --- pixels/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 1c48ced..d1db231 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -620,11 +620,11 @@ def get_positional_data( label=label, event=event, sigma=sigma, - end_event=end_event, + end_event=end_event, # NOTE: ALWAYS the last arg ) # get positional spike rate, spike count, and occupancy - positional_data = xut.get_vr_positional_data(event, trials) + positional_data = xut.get_vr_positional_data(trials) return positional_data From 8ab87ecba4f4f2240dbc1f3b8123f6b8a20121e2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 13 Jun 2025 19:55:24 +0100 Subject: [PATCH 484/658] add data binning args to allow psd for binned data --- pixels/stream.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index d1db231..42e922d 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -773,6 +773,7 @@ def save_spike_chance(self, spiked, sigma): def get_spatial_psd( self, label, event, end_event=None, sigma=None, units=None, + crop_from=None, use_binned=False, time_bin=None, pos_bin=None, ): """ Get spatial power spectral density of selected units. @@ -781,15 +782,28 @@ def get_spatial_psd( # always put units first, cuz it is like that in # experiemnt.align_trials, otherwise the same cache cannot be loaded - from vision_in_darkness.constants import landmarks # get aligned firing rates and positions - trials = self.get_positional_data( - units=units, # NOTE: ALWAYS the first arg - label=label, - event=event, - sigma=sigma, - end_event=end_event, - ) + if not use_binned: + trials = self.get_positional_data( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg + ) + crop_from = crop_from + else: + trials = self.get_binned_trials( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + ) + crop_from = crop_from // pos_bin + 1 + # get positional fr pos_fr = trials["pos_fr"] @@ -797,14 +811,10 @@ def get_spatial_psd( psds = {} for s, start in enumerate(starts): data = pos_fr.xs(start, level="start", axis=1) - # remove black wall and post last landmark - cropped = data.loc[landmarks[0]:landmarks[-1], :] - # TODO may 30 2025: - # only remove 60cm black wall in light, remove first 50cm of tunnel - # anyways in dark! 
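The interpolation added in PATCH 478 above only bites for binned data: once rates are grouped by position bin and reindexed onto the full grid, any bin the animal never occupied shows up as NaN, and a linear fill keeps later steps (such as the spatial PSD) from seeing artificial gaps. A toy sketch of that reindex-then-interpolate sequence (bin labels and unit names are invented):

    import numpy as np
    import pandas as pd

    # positional rates for one trial on a coarse bin grid; bin 13 was skipped
    rates = pd.DataFrame(
        {"unit_1": [2.0, 3.0, 4.0], "unit_2": [1.0, 1.5, 2.5]},
        index=pd.Index([12, 14, 15], name="position"),
    )

    full_grid = np.arange(12, 16)
    reindexed = rates.reindex(full_grid)                       # bin 13 -> NaN
    filled = reindexed.interpolate(method="linear", axis=0)    # linear fill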
+ # crop if needed + cropped = data.loc[crop_from:, :] # get power spectral density - #psds[start] = xut.get_spatial_psd(data) psds[start] = xut.get_spatial_psd(cropped) psd_df = pd.concat( From ee394a24e00007ba70c477c000081b097ae8c2e2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:34:50 +0100 Subject: [PATCH 485/658] make sure unit is the first arg --- pixels/behaviours/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 75ef410..e2cf8ee 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2695,9 +2695,9 @@ def get_binned_trials( f"in {units}." ) binned[stream_id] = stream.get_binned_trials( + units=units, # NOTE: always the first arg! label=label, event=event, - units=units, sigma=sigma, end_event=end_event, time_bin=time_bin, From a35ef682a080e03d5da5c77cae47dc91faeae6ea Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:35:02 +0100 Subject: [PATCH 486/658] implement chance at stream level --- pixels/behaviours/base.py | 65 +++++++++------------------------------ 1 file changed, 15 insertions(+), 50 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e2cf8ee..0050af3 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2706,62 +2706,27 @@ def get_binned_trials( return binned - def get_chance_data(self, time_bin, pos_bin): + def get_chance_positional_psd(self, units, label, event, sigma, end_event): streams = self.files["pixels"] + chance_psd = {} for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) - paths = { - "spiked_memmap_path": self.interim /\ - stream_files["spiked_shuffled_memmap"], - "fr_memmap_path": self.interim /\ - stream_files["fr_shuffled_memmap"], - "memmap_shape_path": self.interim /\ - stream_files["shuffled_shape"], - "idx_path": self.interim / stream_files["shuffled_index"], - "col_path": self.interim /\ - stream_files["shuffled_columns"], - } - - # TODO apr 3 2025: how the fuck to get positions here???? - # TEMP: get it manually... - # light - #pos_path = self.interim /\ - # "cache/align_trials_all_trial_times_725_1_100_512.h5" - # dark - pos_path = self.interim /\ - "cache/align_trials_all_trial_times_1322_1_100_512.h5" - - with pd.HDFStore(pos_path, "r") as store: - # list all keys - keys = store.keys() - # create df as a dictionary to hold all dfs - df = {} - # TODO apr 2 2025: for now the nested dict have keys in the - # format of `/imec0.ap/positions`, this will not be the case - # once i flatten files at the stream level rather than - # session level, i.e., every pixels related cache will have - # stream id in their name. 
- for key in keys: - # read current df - data = store[key] - # remove "/" in key - key_name = key.lstrip("/") - # use key name as dict key - df[key_name] = data - positions = df[f"{stream_id[:-3]}/positions"] - - xut.get_spike_chance( - sample_rate=self.SAMPLE_RATE, - positions=positions, - time_bin=time_bin, - pos_bin=pos_bin, - **paths, + chance_psd[stream_id] = stream.get_chance_positional_psd( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg ) - assert 0 - #spiked_chance = ioutils.read_hdf5(spiked_chance_path, key="spiked") - return None + return chance_psd def save_spike_chance(self, spiked, sigma, sample_rate): From d9469641a1cc3f654b4d17d2b2ea412b0ca2e530 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:36:19 +0100 Subject: [PATCH 487/658] use welch method by default for psd estimation --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index ddca528..fee81fa 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1372,7 +1372,7 @@ def _get_vr_positional_neural_data(positions, data_type, data): def get_spatial_psd(pos_fr): def _compute_psd(col): x = col.dropna().values.squeeze() - f, psd = math_utils.estimate_power_spectrum(x) + f, psd = math_utils.estimate_power_spectrum(x, use_welch=True) # remove 0 to avoid infinity f = f[1:] psd = psd[1:] From 074bbc92ea4926235c5c105362548b8baf56e577 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:36:50 +0100 Subject: [PATCH 488/658] no need to get bin args since we will separate them --- pixels/pixels_utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index fee81fa..7c7b306 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -996,24 +996,23 @@ def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, return None -def get_spike_chance(sample_rate, positions, time_bin, pos_bin, - spiked_memmap_path, fr_memmap_path, memmap_shape_path, - idx_path, col_path): +def get_spike_chance(sample_rate, positions, spiked_memmap_path, fr_memmap_path, + memmap_shape_path, idx_path, col_path): if not fr_memmap_path.exists(): raise PixelsError("\nHave you saved spike chance data yet?") else: # TODO apr 3 2025: we need to get positions here for binning!!! # BUT HOW???? - _get_spike_chance(sample_rate, positions, time_bin, pos_bin, - spiked_memmap_path, fr_memmap_path, memmap_shape_path, - idx_path, col_path) + fr_chance, idx, cols = _get_spike_chance( + sample_rate, positions, spiked_memmap_path, fr_memmap_path, + memmap_shape_path, idx_path, col_path) + return fr_chance, idx, cols return None -def _get_spike_chance(sample_rate, positions, time_bin, pos_bin, - spiked_memmap_path, fr_memmap_path, memmap_shape_path, - idx_path, col_path): +def _get_spike_chance(sample_rate, positions, spiked_memmap_path, + fr_memmap_path, memmap_shape_path, idx_path, col_path): # TODO apr 9 2025: # i do not need to save shape to file, all i need is unit count, repeat, @@ -1039,6 +1038,11 @@ def _get_spike_chance(sample_rate, positions, time_bin, pos_bin, col_df = read_hdf5(col_path, key="cols") cols = pd.Index(col_df["unit"]) + return fr_chance, idx, cols + assert 0 + + # TODO jun 16 2025: + # have a separate func for binning chance data binned_shuffle = {} temp = {} # TODO apr 3 2025: implement multiprocessing here! 
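PATCH 487 above flips the default spectral estimator to Welch's method, so it is worth stating the trade-off: averaging periodograms over overlapping segments lowers the variance of the PSD estimate at the cost of frequency resolution, which suits short, noisy positional firing-rate profiles. A minimal SciPy sketch (the 1 sample-per-cm rate and the segment length are illustrative, not values from the repository):

    import numpy as np
    from scipy import signal

    rng = np.random.default_rng(0)
    fs = 1.0                                   # samples per cm -> cycles/cm
    x = np.arange(600)                         # 600 cm of tunnel
    rate = 5 + np.sin(2 * np.pi * x / 120) + rng.normal(0, 0.5, x.size)

    # Welch: average periodograms of overlapping segments
    freqs, psd = signal.welch(rate, fs=fs, nperseg=256)

    # drop the DC bin before downstream use, as the pipeline does
    freqs, psd = freqs[1:], psd[1:]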
From 884017d917e6ff464d2c230173160253463dbfe1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:38:04 +0100 Subject: [PATCH 489/658] use minimum as start of index not 0 --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 7c7b306..46e022b 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1259,7 +1259,7 @@ def _get_vr_positional_neural_data(positions, data_type, data): # get constants from vd from vision_in_darkness.constants import TUNNEL_RESET # create position indices - indices_range = [0, TUNNEL_RESET+1] + indices_range = [np.floor(positions.min().min()), TUNNEL_RESET+1] # get trial ids trial_ids = positions.columns.get_level_values("trial") From fad145e882957558432e0430e2ff87a0fe3d929e Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:38:32 +0100 Subject: [PATCH 490/658] incorporate chance data indexing --- pixels/pixels_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 46e022b..f04b0bc 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1286,11 +1286,19 @@ def _get_vr_positional_neural_data(positions, data_type, data): trial_pos = trial_pos[trial_pos <= indices[-1]] # get firing rates for current trial of all units - trial_data = data.xs( - key=trial, - axis=1, - level="trial", - ).dropna(how="all").copy() + try: + trial_data = data.xs( + key=trial, + axis=1, + level="trial", + ).dropna(how="all").copy() + except TypeError: + # chance data has trial and time on index, not column + trial_data = data.T.xs( + key=trial, + axis=1, + level="trial", + ).T.dropna(how="all").copy() # get all indices before post reset no_post_reset = trial_data.index.intersection(trial_pos.index) From e9bc2860f62221e56dbf718847136c7e2ae031f8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:38:47 +0100 Subject: [PATCH 491/658] remove negative positions --- pixels/pixels_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index f04b0bc..4ee7e86 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1377,6 +1377,10 @@ def _get_vr_positional_neural_data(positions, data_type, data): ).dropna(how="all") occupancy = occupancy.dropna(how="all") + # remove negative position values + if occupancy.index.min() < 0: + occupancy = occupancy.loc[0:, :] + pos_data = pos_data.loc[0:, :] return pos_data, occupancy From 57cfc60c8319345d1978ea977651f1918969b28d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:39:08 +0100 Subject: [PATCH 492/658] add notes --- pixels/stream.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 42e922d..900779e 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -778,6 +778,10 @@ def get_spatial_psd( """ Get spatial power spectral density of selected units. """ + # NOTE: jun 19 2025 + # potentially we could use aligned trials directly for psd estimation, + # with trial position as x, fr as y, and use lomb-scargle method + # NOTE: order of args matters for loading the cache! 
# always put units first, cuz it is like that in # experiemnt.align_trials, otherwise the same cache cannot be loaded From a9361b53eeeb5307cf252b7744e5a5be52644957 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:39:29 +0100 Subject: [PATCH 493/658] get chance data --- pixels/stream.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 900779e..6870ada 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -829,3 +829,41 @@ def get_spatial_psd( # all nan in other starts, so remember to dropna(axis=1)! return psd_df + + + def get_spike_chance(self, units, label, event, sigma, end_event): + trials = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="trial_rate", # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg + ) + positions = trials["positions"] + + probe_id = self.stream_id[:-3] + name = self.session.name + paths = { + "spiked_memmap_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_spiked_shuffled.bin", + "fr_memmap_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_fr_shuffled.bin", + "memmap_shape_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_shuffled_shape.json", + "idx_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_shuffled_index.h5", + "col_path": self.interim/\ + f"{name}_{probe_id}_{label.name}_shuffled_columns.h5", + } + + fr_chance, idx, cols = xut.get_spike_chance( + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + positions=positions, + **paths, + ) + + return positions, fr_chance, idx, cols + + + From e98a75870e57ae06f2e4cedb3051b93bbc10120c Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 19 Jun 2025 20:39:37 +0100 Subject: [PATCH 494/658] get chance psd --- pixels/stream.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 6870ada..cffd44d 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -866,4 +866,48 @@ def get_spike_chance(self, units, label, event, sigma, end_event): return positions, fr_chance, idx, cols + @cacheable + def get_chance_positional_psd(self, units, label, event, sigma, end_event): + assert 0 + from vision_in_darkness.constants import PRE_DARK_LEN, landmarks + positions, fr_chance, idx, cols = self.get_spike_chance( + units, + label, + event, + sigma, + end_event, + ) + + psds = {} + for r in range(fr_chance.shape[-1]): + repeat = fr_chance[:, :, r] + fr = pd.DataFrame(repeat, index=idx, columns=cols) + + pos_fr, _ = xut._get_vr_positional_neural_data( + positions=positions, + data_type="spike_rate", + data=fr, + ) + + psds[r] = {} + starts = pos_fr.columns.get_level_values("start").unique() + for start in starts: + start_df = pos_fr.xs( + start, + level="start", + axis=1, + ).dropna(how="all") + + cropped = start_df.loc[start+PRE_DARK_LEN:landmarks[-1]-1, :] + psds[r][start] = xut.get_spatial_psd(cropped) + + psd_df = pd.concat( + psds[r], + names=["start","frequency"], + levels=["repeat", "unit", "trial"], + ) + psds[r] = psd_df + + return pd.concat(psds, axis=1) + From f6f8363e8bd0eb2f517c5ff28c11b895bfbbfaf3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:24:35 +0100 Subject: [PATCH 495/658] add defautl repeats for chance --- pixels/constants.py | 2 ++ pixels/pixels_utils.py | 1 + 2 files changed, 3 insertions(+) diff --git a/pixels/constants.py b/pixels/constants.py index 4c55af1..f5aa59d 100644 --- 
a/pixels/constants.py +++ b/pixels/constants.py @@ -13,3 +13,5 @@ BEHAVIOUR_HZ = 25000 np.random.seed(BEHAVIOUR_HZ) + +REPEATS = 100 diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4ee7e86..4cd88bc 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -20,6 +20,7 @@ from pixels.ioutils import write_hdf5, reindex_by_longest from pixels.error import PixelsError from pixels.configs import * +from pixels.constants import * from common_utils import math_utils from common_utils.file_utils import init_memmap, read_hdf5 From b44cb7a504d9bfdf04563d244b8709c112383515 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:25:28 +0100 Subject: [PATCH 496/658] use a separate func to get chance paths and positions to speed up multiprcocessing --- pixels/stream.py | 64 +++++++++++++++++------------------------------- 1 file changed, 23 insertions(+), 41 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index cffd44d..3f8a89e 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -832,6 +832,24 @@ def get_spatial_psd( def get_spike_chance(self, units, label, event, sigma, end_event): + positions, paths = self._get_chance_args( + units, + label, + event, + sigma, + end_event, + ) + + fr_chance, idx, cols = xut.get_spike_chance( + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + positions=positions, + **paths, + ) + + return positions, fr_chance, idx, cols + + + def _get_chance_args(self, units, label, event, sigma, end_event): trials = self.align_trials( units=units, # NOTE: ALWAYS the first arg data="trial_rate", # NOTE: ALWAYS the second arg @@ -857,20 +875,13 @@ def get_spike_chance(self, units, label, event, sigma, end_event): f"{name}_{probe_id}_{label.name}_shuffled_columns.h5", } - fr_chance, idx, cols = xut.get_spike_chance( - sample_rate=self.BEHAVIOUR_SAMPLE_RATE, - positions=positions, - **paths, - ) - - return positions, fr_chance, idx, cols + return positions, paths @cacheable def get_chance_positional_psd(self, units, label, event, sigma, end_event): - assert 0 from vision_in_darkness.constants import PRE_DARK_LEN, landmarks - positions, fr_chance, idx, cols = self.get_spike_chance( + positions, paths = self._get_chance_args( units, label, event, @@ -878,36 +889,7 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): end_event, ) - psds = {} - for r in range(fr_chance.shape[-1]): - repeat = fr_chance[:, :, r] - fr = pd.DataFrame(repeat, index=idx, columns=cols) - - pos_fr, _ = xut._get_vr_positional_neural_data( - positions=positions, - data_type="spike_rate", - data=fr, - ) - - psds[r] = {} - starts = pos_fr.columns.get_level_values("start").unique() - for start in starts: - start_df = pos_fr.xs( - start, - level="start", - axis=1, - ).dropna(how="all") - - cropped = start_df.loc[start+PRE_DARK_LEN:landmarks[-1]-1, :] - psds[r][start] = xut.get_spatial_psd(cropped) - - psd_df = pd.concat( - psds[r], - names=["start","frequency"], - levels=["repeat", "unit", "trial"], - ) - psds[r] = psd_df - - return pd.concat(psds, axis=1) - + logging.info("> getting chance psd") + psds = xut.save_chance_psd(self.BEHAVIOUR_SAMPLE_RATE, positions, paths) + return psds From a82ccf570d44b820bcdbb99ad8400de7d8f5314f Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:26:15 +0100 Subject: [PATCH 497/658] use sptial sampling frequency to get index step --- pixels/pixels_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4cd88bc..82e734b 100644 
--- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1247,6 +1247,7 @@ def _get_vr_positional_neural_data(positions, data_type, data): shape: position x trial """ from pandas.api.types import is_integer_dtype + from vision_in_darkness.constants import SPATIAL_SAMPLE_RATE if "bin" in positions.index.name: logging.info(f"\n> Getting binned positional {data_type}...") @@ -1266,7 +1267,7 @@ def _get_vr_positional_neural_data(positions, data_type, data): trial_ids = positions.columns.get_level_values("trial") # create position indices - indices = np.arange(*indices_range).astype(int) + indices = np.arange(*indices_range, SPATIAL_SAMPLE_RATE).astype(int) # create occupancy array for trials occupancy = pd.DataFrame( data=np.full((len(indices), positions.shape[1]), np.nan), From 13717ae6f5919d7da260b5d1ac23474937856794 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:26:40 +0100 Subject: [PATCH 498/658] use multiprocessing to get chance psd --- pixels/pixels_utils.py | 86 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 82e734b..520418e 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1402,3 +1402,89 @@ def _compute_psd(col): psd = pos_fr.apply(_compute_psd, axis=0) return psd + + +def _psd_chance_worker(r, sample_rate, positions, paths, cut_start, cut_end): + """ + Worker that computes one set of psd. + + params + === + i: index of current repeat. + + return + === + """ + logging.info(f"\nProcessing repeat {r}...") + chance_data, idx, cols = get_spike_chance( + sample_rate=sample_rate, + positions=positions, + **paths, + ) + + pos_fr, _ = _get_vr_positional_neural_data( + positions=positions, + data_type="spike_rate", + data=pd.DataFrame(chance_data[..., r], index=idx, columns=cols), + ) + del chance_data, positions + + psd = {} + starts = pos_fr.columns.get_level_values("start").unique() + + for start in starts: + start_df = pos_fr.xs( + start, + level="start", + axis=1, + ).dropna(how="all") + + cropped = start_df.loc[start+cut_start:cut_end, :] + psd[start] = get_spatial_psd(cropped) + + del cropped, start_df + + psd_df = pd.concat( + psd, + names=["start", "frequency"], + ) + + logging.info(f"\nRepeat {r} finished.") + + return psd_df + + +def save_chance_psd(sample_rate, positions, paths):#chance_data, idx, cols): + """ + Implementation of saving chance level spike data. + """ + #import concurrent.futures + from concurrent.futures import ProcessPoolExecutor, as_completed + from vision_in_darkness.constants import PRE_DARK_LEN, landmarks + + # Set up the process pool to run the worker in parallel. + # Submit jobs for each repeat. 
+ futures = [] + with ProcessPoolExecutor() as executor: + for r in range(REPEATS): + future = executor.submit( + _psd_chance_worker, + r, + sample_rate, + positions, + paths, + PRE_DARK_LEN, + landmarks[-1] - 1 + ) + futures.append(future) + # collect and concat + results = [f.result() for f in as_completed(futures)] + + psds = pd.concat( + results, + axis=1, + keys=range(REPEATS), + names=["repeat", "unit", "trial"], + ) + + return psds From 6b5ab82700c5a60ea0bc6cc83668d250c8266ccf Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:27:16 +0100 Subject: [PATCH 499/658] add start level --- pixels/experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 8ff299d..5aae149 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -280,7 +280,7 @@ def align_trials(self, *args, units=None, **kwargs): positions = ioutils.get_aligned_data_across_sessions( trials=trials, key="positions", - level_names=["session", "stream", "trial"], + level_names=["session", "stream", "start", "trial"], ) df = { "fr": fr, @@ -562,7 +562,7 @@ def get_binned_trials(self, *args, units=None, **kwargs): if not result is None: trials[name] = result - level_names = ["session", "stream", "unit", "trial"] + level_names = ["session", "stream", "start", "unit", "trial"] bin_fr = ioutils.get_aligned_data_across_sessions( trials=trials, key="bin_fr", From 8e894f73ef99355755e3bbfe178b7e6853014b28 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:27:54 +0100 Subject: [PATCH 500/658] change key names --- pixels/experiment.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 5aae149..e7a3062 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -565,23 +565,23 @@ def get_binned_trials(self, *args, units=None, **kwargs): level_names = ["session", "stream", "start", "unit", "trial"] bin_fr = ioutils.get_aligned_data_across_sessions( trials=trials, - key="bin_fr", + key="pos_fr", level_names=level_names, ) bin_fc = ioutils.get_aligned_data_across_sessions( trials=trials, - key="bin_fc", + key="pos_fc", level_names=level_names, ) - bin_pos = ioutils.get_aligned_data_across_sessions( + bin_occupancies = ioutils.get_aligned_data_across_sessions( trials=trials, - key="bin_pos", - level_names=["session", "stream", "pos_type", "trial"], + key="occupancy", + level_names=["session", "stream", "trial"], ) df = { "bin_fr": bin_fr, "bin_fc": bin_fc, - "bin_pos": bin_pos, + "bin_occupancy": bin_occupancies, } return df From be46e2ce38bad8a7f18fdc528845f50a69322569 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:28:04 +0100 Subject: [PATCH 501/658] add todo --- pixels/experiment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/experiment.py b/pixels/experiment.py index e7a3062..bf5ca40 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -545,6 +545,9 @@ def get_binned_trials(self, *args, units=None, **kwargs): Get binned firing rate and spike count for aligned vr trials. Check behaviours.base.Behaviour.get_binned_trials for usage information. 
""" + # TODO jun 21 2025: + # can we combine this func with get_positional_data since they are + # basically the same, we just need to add `use_binned` bool in the arg session_names = [session.name for session in self.sessions] trials = {} if not units is None: From 17b477a4117faeff0409fc91df80b5c28e732f88 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 21 Jun 2025 20:38:14 +0100 Subject: [PATCH 502/658] check if df is dict before checking if it is nested dict --- pixels/decorators.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 13b175e..17aa3cb 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -76,25 +76,26 @@ def wrapper(*args, **kwargs): logging.info("\n> df is None, cache will exist but be empty.") else: # allows to save multiple dfs in a dict in one hdf5 file - if ioutils.is_nested_dict(df): - for probe_id, nested_dict in df.items(): - # NOTE: we remove `.ap` in stream id cuz having `.`in - # the key name get problems - for name, values in nested_dict.items(): + if isinstance(df, dict): + if ioutils.is_nested_dict(df): + for probe_id, nested_dict in df.items(): + # NOTE: we remove `.ap` in stream id cuz having `.`in + # the key name get problems + for name, values in nested_dict.items(): + ioutils.write_hdf5( + path=cache_path, + df=values, + key=f"/{probe_id}/{name}", + mode="a", + ) + else: + for name, values in df.items(): ioutils.write_hdf5( path=cache_path, df=values, - key=f"/{probe_id}/{name}", + key=name, mode="a", ) - elif isinstance(df, dict): - for name, values in df.items(): - ioutils.write_hdf5( - path=cache_path, - df=values, - key=name, - mode="a", - ) else: ioutils.write_hdf5(cache_path, df) return df From 9866f7ca7583a1c068aba5d5a342598a882978b3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 23 Jun 2025 17:54:24 +0100 Subject: [PATCH 503/658] remove todo --- pixels/pixels_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 520418e..4c742b7 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1002,8 +1002,6 @@ def get_spike_chance(sample_rate, positions, spiked_memmap_path, fr_memmap_path, if not fr_memmap_path.exists(): raise PixelsError("\nHave you saved spike chance data yet?") else: - # TODO apr 3 2025: we need to get positions here for binning!!! - # BUT HOW???? 
fr_chance, idx, cols = _get_spike_chance( sample_rate, positions, spiked_memmap_path, fr_memmap_path, memmap_shape_path, idx_path, col_path) From 48c9f953be7066e2220737b4f40ba88dc4bc9294 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 23 Jun 2025 17:54:37 +0100 Subject: [PATCH 504/658] if using all units, columns (unit ids) are the same across conditions --- pixels/stream.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 3f8a89e..d718df7 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -871,8 +871,9 @@ def _get_chance_args(self, units, label, event, sigma, end_event): f"{name}_{probe_id}_{label.name}_shuffled_shape.json", "idx_path": self.interim/\ f"{name}_{probe_id}_{label.name}_shuffled_index.h5", + # NOTE: if all units, all conditions share the same columns "col_path": self.interim/\ - f"{name}_{probe_id}_{label.name}_shuffled_columns.h5", + self.files["shuffled_columns"], } return positions, paths From 773b2e60b0fd7ba35401dccc81560897612e9c16 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 28 Jul 2025 11:51:16 +0100 Subject: [PATCH 505/658] allow to normalise before getting positional data --- pixels/behaviours/base.py | 3 +++ pixels/stream.py | 40 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 0050af3..539dd27 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2628,6 +2628,7 @@ def get_aligned_spike_rate_CI( def get_positional_data( self, label, event, end_event=None, sigma=None, units=None, + normalised=False, ): """ Get positional firing rate of selected units in vr, and spatial @@ -2651,12 +2652,14 @@ def get_positional_data( f"\n> Getting positional neural data of {units} units in " f"<{label.name}> trials." ) + output[stream_id] = stream.get_positional_data( units=units, # NOTE: put units first! label=label, event=event, sigma=sigma, end_event=end_event, + normalised=normalised, ) return output diff --git a/pixels/stream.py b/pixels/stream.py index d718df7..c557f11 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -604,6 +604,7 @@ def _extract_binned_data(self, binned_data, pos_cols): @cacheable def get_positional_data( self, label, event, end_event=None, sigma=None, units=None, + normalised=False, ): """ Get positional firing rate of selected units in vr, and spatial @@ -623,6 +624,45 @@ def get_positional_data( end_event=end_event, # NOTE: ALWAYS the last arg ) + if normalised: + grays = self.align_trials( + units=units, # NOTE: ALWAYS the first arg + data="trial_rate", # NOTE: ALWAYS the second arg + label=getattr(label, label.name.split("_")[-1]), + event=event.gray_on, + sigma=sigma, + end_event=end_event.gray_off, # NOTE: ALWAYS the last arg + ) + + # NOTE july 24 2025: if get gray mu & sigma per trial we got z score + # of very quiet & stable in gray units >9e+15... 
thus, we normalise + # average all trials for each unit, rather than per trial + + # NOTE: 500ms of 500Hz sine wave sound at each trial start, 2000ms + # of gray, so only take the second 1000ms in gray to get mean and + # std + + # only select trials exists in aligned trials + baseline = grays["fr"].iloc[ + self.BEHAVIOUR_SAMPLE_RATE: self.BEHAVIOUR_SAMPLE_RATE * 2 + ].loc[:, trials["fr"].columns].T.groupby("unit").mean().T + + #baseline_log = np.log1p(baseline) + #mu_log = baseline_log.mean() + #std_log = baseline_log.std() + #fr_log = np.log1p(fr) + #z_fr_log = (fr_log.sub(mu_log, axis=1, level="unit") + # .div(std_log, axis=1, level="unit")) + + mu = baseline.mean() + centered = trials["fr"].sub(mu, axis=1, level="unit") + std = baseline.std() + z_fr = centered.div(std, axis=1, level="unit") + trials["fr"] = z_fr + + del grays, baseline, mu, std, z_fr + gc.collect() + # get positional spike rate, spike count, and occupancy positional_data = xut.get_vr_positional_data(trials) From b735278db96d80ef0b929c8488a0444cc52877be Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 28 Jul 2025 11:51:42 +0100 Subject: [PATCH 506/658] do not cache get_positional_data --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index c557f11..9d66f06 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -601,7 +601,7 @@ def _extract_binned_data(self, binned_data, pos_cols): return data, pos - @cacheable + #@cacheable def get_positional_data( self, label, event, end_event=None, sigma=None, units=None, normalised=False, From e600ac385d1e1bbb4a191e4cb3821e6d8d6d91f7 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 28 Jul 2025 11:52:50 +0100 Subject: [PATCH 507/658] no need to import Events again, get attr from `event` --- pixels/stream.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 9d66f06..1445e80 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -139,9 +139,8 @@ def _get_aligned_trials( # map actual starting locations if not "trial_start" in event.name: - from pixels.behaviours.virtual_reality import Events all_start_idx = np.where( - np.bitwise_and(events, Events.trial_start) + np.bitwise_and(events, event.trial_start) )[0] start_idx = trials[np.where( np.isin(trials, all_start_idx) From b26cb46b2258a65b5c91f595b47cae69f2786bae Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 28 Jul 2025 11:53:26 +0100 Subject: [PATCH 508/658] make sure to include last frame of the event during alignment --- pixels/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 1445e80..45cd474 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -162,7 +162,7 @@ def _get_aligned_trials( # convolution scan_pad = self.BEHAVIOUR_SAMPLE_RATE scan_starts = start_t - scan_pad - scan_ends = end_t + scan_pad + scan_ends = end_t + scan_pad + 1 scan_durations = scan_ends - scan_starts cursor = 0 @@ -190,7 +190,7 @@ def _get_aligned_trials( trial = rec_spikes[trial_bool] # get position bin ids for current trial trial_pos_bool = (all_pos.index >= start_t[i])\ - & (all_pos.index < end_t[i]) + & (all_pos.index <= end_t[i]) trial_pos = all_pos[trial_pos_bool] # initiate binary spike times array for current trial From f5eb4b7b2998beb2273985e2788bd3dccded2695 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 28 Jul 2025 11:54:07 +0100 Subject: [PATCH 509/658] delete used vars to save memory --- pixels/stream.py | 2 ++ 1 file 
changed, 2 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 45cd474..39176aa 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1,3 +1,5 @@ +import gc + import numpy as np import pandas as pd From c2ec1b19fe24305a9dcb861ac4c379526f0a79b8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 28 Jul 2025 11:54:38 +0100 Subject: [PATCH 510/658] no need to do log transformation --- pixels/stream.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 39176aa..1ab6627 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -648,13 +648,6 @@ def get_positional_data( self.BEHAVIOUR_SAMPLE_RATE: self.BEHAVIOUR_SAMPLE_RATE * 2 ].loc[:, trials["fr"].columns].T.groupby("unit").mean().T - #baseline_log = np.log1p(baseline) - #mu_log = baseline_log.mean() - #std_log = baseline_log.std() - #fr_log = np.log1p(fr) - #z_fr_log = (fr_log.sub(mu_log, axis=1, level="unit") - # .div(std_log, axis=1, level="unit")) - mu = baseline.mean() centered = trials["fr"].sub(mu, axis=1, level="unit") std = baseline.std() From 5e1c54a851fa207a2883785600e1c0e7c19183db Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Jul 2025 17:38:41 +0100 Subject: [PATCH 511/658] set index to times index too --- pixels/signal_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pixels/signal_utils.py b/pixels/signal_utils.py index a585f87..d421638 100644 --- a/pixels/signal_utils.py +++ b/pixels/signal_utils.py @@ -337,7 +337,11 @@ def convolve_spike_trains(times, sigma=100, size=10, sample_rate=1000): axis=0, ) * sample_rate # rescale it to second - output = pd.DataFrame(convolved, columns=times.columns) + output = pd.DataFrame( + convolved, + columns=times.columns, + index=times.index, + ) elif isinstance(times, np.ndarray): # convolve with gaussian From b2c4332304028b0e4a97bf33e38ce24307d7efd3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Jul 2025 17:57:00 +0100 Subject: [PATCH 512/658] no need to pass data --- pixels/stream.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 1ab6627..44d0491 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -88,8 +88,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): f"<{label.name}> trials." 
) return self._get_aligned_trials( - label, event, data=data, units=units, sigma=sigma, - end_event=end_event, + label, event, units=units, sigma=sigma, end_event=end_event, ) else: raise NotImplementedError( @@ -97,7 +96,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): ) def _get_aligned_trials( - self, label, event, data, units=None, sigma=None, end_event=None, + self, label, event, units=None, sigma=None, end_event=None, ): # get synched pixels stream with vr and action labels synched_vr, action_labels = self.get_synched_vr() From 7f4875b74e87578d23f6fd9884e36775fe3d3521 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Jul 2025 18:01:16 +0100 Subject: [PATCH 513/658] implement event alignment --- pixels/stream.py | 197 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 44d0491..efea2e5 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -90,11 +90,20 @@ def align_trials(self, units, data, label, event, sigma, end_event): return self._get_aligned_trials( label, event, units=units, sigma=sigma, end_event=end_event, ) + elif data in ("spike_rate", "spike_times"): + logging.info( + f"\n> Aligning spike times and spike rate of {units} units to " + f"{event} event in <{label.name}> trials." + ) + return self._get_aligned_spike_times( + label, event, units=units, sigma=sigma, + ) else: raise NotImplementedError( "> Other types of alignment are not implemented." ) + def _get_aligned_trials( self, label, event, units=None, sigma=None, end_event=None, ): @@ -298,6 +307,194 @@ def _get_aligned_trials( return output + def _get_aligned_spike_times(self, label, event, units=None, sigma=None): + # get synched pixels stream with vr and action labels + synched_vr, action_labels = self.get_synched_vr() + + # get positions of all trials + all_pos = synched_vr.position_in_tunnel + + # get spike times + spikes = self.get_spike_times(units) + + # get action and event label file + outcomes = action_labels["outcome"] + events = action_labels["events"] + # get timestamps index of behaviour in self.BEHAVIOUR_SAMPLE_RATE hz, to + # convert it to ms, do timestamps*1000/self.BEHAVIOUR_SAMPLE_RATE + timestamps = action_labels["timestamps"] + + # select frames of wanted trial type + trials = np.where(np.bitwise_and(outcomes, label))[0] + # map starts by event + starts = np.where(np.bitwise_and(events, event))[0] + + # only take starts from selected trials + selected_starts = trials[np.where(np.isin(trials, starts))[0]] + start_t = timestamps[selected_starts] + + if selected_starts.size == 0: + logging.info(f"\n> No trials found with label {label} and event " + f"{event.name}, output will be empty.") + return None + + # use original trial id as trial index + trial_ids = pd.Index( + synched_vr.iloc[selected_starts].trial_count.unique() + ) + + # map actual starting locations + if not "trial_start" in event.name: + all_start_idx = np.where( + np.bitwise_and(events, event.trial_start) + )[0] + start_idx = trials[np.where( + np.isin(trials, all_start_idx) + )[0]] + else: + start_idx = selected_starts.copy() + + start_pos = synched_vr.position_in_tunnel.iloc[ + start_idx + ].values.astype(int) + + # create multiindex with starts + cols_with_starts = pd.MultiIndex.from_arrays( + [start_pos, trial_ids], + names=("start", "trial"), + ) + + # pad ends with 1 second extra to remove edge effects from convolution, + # during of event is 2s + duration = 2 + pad_duration = 1 + scan_pad = 
self.BEHAVIOUR_SAMPLE_RATE + one_side_frames = scan_pad * (duration + pad_duration) + scan_starts = start_t - one_side_frames + scan_ends = start_t + one_side_frames + 1 + scan_duration = one_side_frames * 2 + 1 + relative_idx = np.linspace( + -(duration+pad_duration), + (duration+pad_duration), + scan_duration, + ) + + cursor = 0 + raw_rec = self.load_raw_ap() + samples = raw_rec.get_total_samples() + # Account for multiple raw data files + in_SAMPLE_RATE_scale = (samples * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + cursor_duration = (cursor * self.BEHAVIOUR_SAMPLE_RATE)\ + / raw_rec.sampling_frequency + rec_spikes = spikes[ + (cursor_duration <= spikes)\ + & (spikes < (cursor_duration + in_SAMPLE_RATE_scale)) + ] - cursor_duration + cursor += samples + + output = {} + trials_fr = {} + trials_spiked = {} + trials_positions = {} + for i, start in enumerate(selected_starts): + # select spike times of event in current trial + trial_bool = (rec_spikes >= scan_starts[i])\ + & (rec_spikes <= scan_ends[i]) + trial = rec_spikes[trial_bool] + # get position bin ids for current trial + #trial_pos_bool = (all_pos.index >= start_t[i])\ + # & (all_pos.index <= end_t[i]) + #trial_pos = all_pos[trial_pos_bool] + + # initiate binary spike times array for current trial + # NOTE: dtype must be float otherwise would get all 0 when passing + # gaussian kernel + times = np.zeros((scan_duration, len(units))).astype(float) + # use pixels time as spike index + idx = np.arange(scan_starts[i], scan_ends[i]) + # make it df, column name being unit id + spiked = pd.DataFrame(times, index=idx, columns=units) + + for unit in trial: + # get spike time for unit + u_times = trial[unit].values + # drop nan + u_times = u_times[~np.isnan(u_times)] + # round spike times to use it as index + u_spike_idx = np.round(u_times).astype(int) + # make sure it does not exceed scan duration + if (u_spike_idx >= scan_ends[i]).any(): + beyonds = np.where(u_spike_idx >= scan_ends[i])[0] + u_spike_idx[beyonds] = idx[-1] + # make sure no double counted + u_spike_idx = np.unique(u_spike_idx) + + # set spiked to 1 + spiked.loc[u_spike_idx, unit] = 1 + + # set spiked index to relative index + spiked.index = relative_idx + # convolve spike trains into spike rates + rates = signal.convolve_spike_trains( + times=spiked, + sigma=sigma, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + ) + + # remove 1s padding from the start and end + rates = rates.iloc[scan_pad: -scan_pad] + spiked = spiked.iloc[scan_pad: -scan_pad] + + trials_fr[trial_ids[i]] = rates + trials_spiked[trial_ids[i]] = spiked + + # get trials vertically stacked spiked + stacked_spiked = pd.concat( + trials_spiked, + axis=0, + ) + stacked_spiked.index.names = ["trial", "time"] + stacked_spiked.columns.names = ["unit"] + + # TODO apr 21 2025: + # save spike chance only if all units are selected, else + # only index into the big chance array and save into zarr + #if units.name == "all" and (label == 725 or 1322): + # self.save_spike_chance( + # stream_files=stream_files, + # spiked=stacked_spiked, + # sigma=sigma, + # ) + #else: + # # access chance data if we only need part of the units + # self.get_spike_chance( + # sample_rate=self.SAMPLE_RATE, + # positions=all_pos, + # sigma=sigma, + # ) + # assert 0 + + # get trials horizontally stacked spiked + spiked = ioutils.reindex_by_longest( + dfs=stacked_spiked, + level="trial", + return_format="dataframe", + ) + fr = ioutils.reindex_by_longest( + dfs=trials_fr, + level="trial", + idx_names=["trial", "time"], + 
col_names=["unit"], + return_format="dataframe", + ) + + output["spiked"] = spiked + output["fr"] = fr + + return output + + def get_spike_times(self, units): # find sorting analyser path sa_path = self.session.find_file(self.files["sorting_analyser"]) From 020b8136693be7f1daf41cb6d6bded22c59488d4 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Jul 2025 18:03:38 +0100 Subject: [PATCH 514/658] use more informative names --- pixels/stream.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index efea2e5..80c1dd9 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -82,7 +82,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): df, output from individual functions according to data type. """ - if "trial_rate" in data: + if "trial" in data: logging.info( f"\n> Aligning spike times and spike rate of {units} units to " f"<{label.name}> trials." @@ -90,12 +90,12 @@ def align_trials(self, units, data, label, event, sigma, end_event): return self._get_aligned_trials( label, event, units=units, sigma=sigma, end_event=end_event, ) - elif data in ("spike_rate", "spike_times"): + elif "event" in data: logging.info( f"\n> Aligning spike times and spike rate of {units} units to " f"{event} event in <{label.name}> trials." ) - return self._get_aligned_spike_times( + return self._get_aligned_events( label, event, units=units, sigma=sigma, ) else: @@ -307,7 +307,7 @@ def _get_aligned_trials( return output - def _get_aligned_spike_times(self, label, event, units=None, sigma=None): + def _get_aligned_events(self, label, event, units=None, sigma=None): # get synched pixels stream with vr and action labels synched_vr, action_labels = self.get_synched_vr() From 41af6f6bc2a470470a87400bc8e1356d2496819f Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Jul 2025 18:03:47 +0100 Subject: [PATCH 515/658] no need to get positions --- pixels/stream.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 80c1dd9..112c144 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -402,10 +402,6 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): trial_bool = (rec_spikes >= scan_starts[i])\ & (rec_spikes <= scan_ends[i]) trial = rec_spikes[trial_bool] - # get position bin ids for current trial - #trial_pos_bool = (all_pos.index >= start_t[i])\ - # & (all_pos.index <= end_t[i]) - #trial_pos = all_pos[trial_pos_bool] # initiate binary spike times array for current trial # NOTE: dtype must be float otherwise would get all 0 when passing From 068c55015d401ef49e53401d15b232de655d23af Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 29 Jul 2025 18:04:02 +0100 Subject: [PATCH 516/658] use align_trials implementation in stream --- pixels/behaviours/base.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 539dd27..97be774 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1953,24 +1953,27 @@ def align_trials( if data not in data_options: raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") - if data in ("spike_times", "spike_rate"): - logging.info(f"\nAligning {data} to trials.") - # we let a dedicated function handle aligning spike times - return self._get_aligned_spike_times( - label, event, duration, rate=data == "spike_rate", sigma=sigma, - units=units, - ) + streams = self.files["pixels"] + output = {} - if "trial" 
in data: - logging.info( - f"\n> Aligning {self.name} {data} of {units} units to label " - f"<{label.name}> trials." + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, ) - return self._get_aligned_trials( - label, event, data=data, units=units, sigma=sigma, + output[stream_id] = stream.align_trials( + units=units, # NOTE: ALWAYS the first arg + data=data, # NOTE: ALWAYS the second arg + label=label, + event=event, + sigma=sigma, end_event=end_event, ) + return output + if data == "motion_tracking" and not dlc_project: raise PixelsError("When aligning to 'motion_tracking', dlc_project is needed.") From 2897bb6fd103a729b14750d8cf2ae94a1e96c966 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 31 Jul 2025 12:32:28 +0100 Subject: [PATCH 517/658] do not copy raw to interim --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 112c144..6fa422f 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -43,7 +43,7 @@ def __repr__(self): def load_raw_ap(self): - paths = [self.session.find_file(path) for path in self.files["ap_raw"]] + paths = [self.session.find_file(path, copy=False) for path in self.files["ap_raw"]] self.files["si_rec"] = xut.load_raw(paths, self.stream_id) return self.files["si_rec"] From 2746643a7e80c546f6572a530e1fa313c8eaab70 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 31 Jul 2025 12:32:55 +0100 Subject: [PATCH 518/658] print name not event attr --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 6fa422f..8d06b81 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -93,7 +93,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): elif "event" in data: logging.info( f"\n> Aligning spike times and spike rate of {units} units to " - f"{event} event in <{label.name}> trials." + f"{event.name} event in <{label.name}> trials." 
) return self._get_aligned_events( label, event, units=units, sigma=sigma, From 018bb08d7f4c0a59298065113d6a6bfb4b511687 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 31 Jul 2025 12:33:20 +0100 Subject: [PATCH 519/658] reduce duration of event to 1s each side --- pixels/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 8d06b81..c57df62 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -365,8 +365,8 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): ) # pad ends with 1 second extra to remove edge effects from convolution, - # during of event is 2s - duration = 2 + # during of event is 2s (pre + post) + duration = 1 pad_duration = 1 scan_pad = self.BEHAVIOUR_SAMPLE_RATE one_side_frames = scan_pad * (duration + pad_duration) From 5fa432e02fef0e66458d92d01568b993dea72357 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 31 Jul 2025 12:34:18 +0100 Subject: [PATCH 520/658] add starting position to spiked and fr --- pixels/stream.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index c57df62..2656084 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -358,11 +358,8 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): start_idx ].values.astype(int) - # create multiindex with starts - cols_with_starts = pd.MultiIndex.from_arrays( - [start_pos, trial_ids], - names=("start", "trial"), - ) + # map starting position with trial + start_trial_maps = dict(zip(trial_ids, start_pos)) # pad ends with 1 second extra to remove edge effects from convolution, # during of event is 2s (pre + post) @@ -477,6 +474,20 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): level="trial", return_format="dataframe", ) + trial_cols = spiked.columns.get_level_values("trial") + start_cols = trial_cols.map(start_trial_maps) + unit_cols = spiked.columns.get_level_values("unit") + new_cols = pd.MultiIndex.from_arrays( + [unit_cols, start_cols, trial_cols], + names=["unit", "start", "trial"] + ) + spiked.columns = new_cols + spiked = spiked.sort_index( + axis=1, + level=["unit", "start", "trial"], + ascending=[True, False, True], + ) + fr = ioutils.reindex_by_longest( dfs=trials_fr, level="trial", @@ -484,6 +495,8 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): col_names=["unit"], return_format="dataframe", ) + fr.columns = new_cols + fr = fr.loc[:, spiked.columns] output["spiked"] = spiked output["fr"] = fr From 6ae31f15253d61224eda83eacc63211754088d21 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 31 Jul 2025 12:35:09 +0100 Subject: [PATCH 521/658] put spike in name just to differentiate with lfp --- pixels/behaviours/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 97be774..01b4f19 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1946,9 +1946,12 @@ def align_trials( 'lfp', # Raw/downsampled channels from probe (LFP) 'motion_index', # Motion index per ROI from the video 'motion_tracking', # Motion tracking coordinates from DLC - 'trial_rate', # Taking spike times from the whole duration of each + 'spike_trial', # Taking spike times from the whole duration of each # trial, convolve into spike rate, output also # contains times. + 'spike_event', # Taking spike times from +/- 2s of an event, + # convolve into spike rate, output also + # contains times. 
] if data not in data_options: raise PixelsError(f"align_trials: 'data' should be one of: {data_options}") From bde43b5a9ee0d0eb64871f90d604bdf9220a357a Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 31 Jul 2025 12:37:18 +0100 Subject: [PATCH 522/658] moved implementation to stream.py --- pixels/behaviours/base.py | 133 -------------------------------------- 1 file changed, 133 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 01b4f19..9bbebe4 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1554,139 +1554,6 @@ def get_spike_times(self, units, remapped=False, use_si=False): return spike_times[0] - def _get_aligned_spike_times( - self, label, event, duration, rate=False, sigma=None, units=None - ): - """ - Returns spike times for each unit within a given time window around an event. - align_trials delegates to this function, and should be used for getting aligned - data in scripts. - """ - action_labels = self.get_action_labels()[0] - - if units is None: - units = self.select_units() - - #TODO: with multiple streams, spike times will be a list with multiple dfs, - #make sure old code does not break! - #TODO: spike times cannot be indexed by unit ids anymore - spikes = self.get_spike_times()[units] - - if rate: - # pad ends with 1 second extra to remove edge effects from convolution - duration += 2 - - scan_duration = self.SAMPLE_RATE * 8 - half = int((self.SAMPLE_RATE * duration) / 2) - cursor = 0 # In sample points - i = -1 - rec_trials = {} - - for rec_num in range(len(self.files)): - actions = action_labels[rec_num][:, 0] - events = action_labels[rec_num][:, 1] - trial_starts = np.where(np.bitwise_and(actions, label))[0] - - # Account for multiple raw data files - meta = self.spike_meta[rec_num] - samples = int(meta["fileSizeBytes"]) / int(meta["nSavedChans"]) / 2 - assert samples.is_integer() - milliseconds = samples / 30 - cursor_duration = cursor / 30 - rec_spikes = spikes[ - (cursor_duration <= spikes) & (spikes < (cursor_duration + milliseconds)) - ] - cursor_duration - cursor += samples - - # Account for lag, in case the ephys recording was started before the - # behaviour - lag_start, _ = self._lag[rec_num] - if lag_start < 0: - rec_spikes = rec_spikes + lag_start - - for i, start in enumerate(trial_starts, start=i + 1): - centre = np.where(np.bitwise_and(events[start:start + scan_duration], event))[0] - if len(centre) == 0: - # See comment in align_trials as to why we just continue instead of - # erroring like we used to here. - logging.info("\nNo event found for an action. 
If this is OK, " - "ignore this.") - continue - centre = start + centre[0] - - trial = rec_spikes[centre - half < rec_spikes] - trial = trial[trial <= centre + half] - trial = trial - centre - tdf = [] - - for unit in trial: - u_times = trial[unit].values - u_times = u_times[~np.isnan(u_times)] - u_times = np.unique(u_times) # remove double-counted spikes - udf = pd.DataFrame({int(unit): u_times}) - tdf.append(udf) - - assert len(tdf) == len(units) - if tdf: - tdfc = pd.concat(tdf, axis=1) - if rate: - tdfc = signal.convolve(tdfc, duration * 1000, sigma) - rec_trials[i] = tdfc - - if not rec_trials: - return None - - trials = pd.concat(rec_trials, axis=1, names=["trial", "unit"]) - trials = trials.reorder_levels(["unit", "trial"], axis=1) - trials = trials.sort_index(level=0, axis=1) - - if rate: - # Set index to seconds and remove the padding 1 sec at each end - points = trials.shape[0] - start = (- duration / 2) + (duration / points) - # Having trouble with float values - #timepoints = np.linspace(start, duration / 2, points, dtype=np.float64) - timepoints = list(range(round(start * 1000), int(duration * 1000 / 2) + 1)) - trials['time'] = pd.Series(timepoints, index=trials.index) / 1000 - trials = trials.set_index('time') - trials = trials.iloc[self.SAMPLE_RATE : - self.SAMPLE_RATE] - - return trials - - - def _get_aligned_trials( - self, label, event, data, units=None, sigma=None, end_event=None, - ): - """ - Returns spike rate for each unit within a trial. - align_trials delegates to this function, and should be used for getting aligned - data in scripts. - - This function also saves binned data in the format that Alfredo wants: - trials * units * temporal bins (100ms) - - """ - streams = self.files["pixels"] - output = {} - - for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - stream = Stream( - stream_id=stream_id, - stream_num=stream_num, - files=stream_files, - session=self, - ) - output[stream_id] = stream.align_trials( - units=units, # NOTE: ALWAYS the first arg - data="trial_rate", # NOTE: ALWAYS the second arg - label=label, - event=event, - sigma=sigma, - end_event=end_event, - ) - - return output - def select_units( self, group='good', min_depth=0, max_depth=None, min_spike_width=None, From 8876efcb20a63c679e5ceb9487d1dd1952e609bb Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 31 Jul 2025 12:39:01 +0100 Subject: [PATCH 523/658] use cacheable decor in stream.py --- pixels/behaviours/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 9bbebe4..66c2f64 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1500,6 +1500,7 @@ def _get_si_spike_times(self, units): return spike_times[0] # NOTE: only deal with one stream for now + def get_spike_times(self, units, remapped=False, use_si=False): """ Returns the sorted spike times. 
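The event alignment above builds binary spiked arrays and converts them to rates with convolve_spike_trains. A minimal sketch of that conversion, assuming a Gaussian kernel and a 1 kHz behaviour sample rate; the exact kernel construction inside convolve_spike_trains may differ:

import numpy as np
from scipy.ndimage import gaussian_filter1d

def spike_train_to_rate(spiked, sigma_ms=100, sample_rate=1000):
    # spiked: 1D binary array with one sample per 1 / sample_rate seconds
    sigma_samples = sigma_ms * sample_rate / 1000
    # smooth the train, then rescale from spikes per sample to spikes per second
    return gaussian_filter1d(spiked.astype(float), sigma_samples) * sample_rate

Padding the aligned window by roughly one second on each side before smoothing and cropping the pads afterwards, as the scan_pad logic above does, keeps the Gaussian tails from distorting the first and last samples.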
@@ -1752,7 +1753,7 @@ def get_lfp_data_raw(self): """ return self._get_neuro_raw('lfp') - #@_cacheable + def align_trials( self, label, event, units=None, data='spike_times', raw=False, duration=1, sigma=None, dlc_project=None, video_match=None, From 557d07de834072824138e145d45924d26287a840 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 8 Aug 2025 16:08:06 +0100 Subject: [PATCH 524/658] add option to allow band extraction without preprocessing --- pixels/behaviours/base.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 66c2f64..df04534 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -636,14 +636,11 @@ def detect_n_localise_peaks(self, loc_method="monopolar_triangulation"): return None - def extract_bands(self, freqs=None): + def extract_bands(self, freqs=None, preprocess=True): """ extract data of ap and lfp frequency bands from the raw neural recording data. """ - # preprocess raw - self.preprocess_raw() - streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): logging.info( @@ -656,7 +653,7 @@ def extract_bands(self, freqs=None): files=stream_files, session=self, ) - stream.extract_bands(freqs) + stream.extract_bands(freqs, preprocess) """ if self._lag[rec_num] is None: From 2e14c9b9a7db66aec126d9662ef52a5f84c025b1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 8 Aug 2025 16:08:28 +0100 Subject: [PATCH 525/658] default to spike_event; remove old doc --- pixels/behaviours/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index df04534..a1e9928 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1752,7 +1752,7 @@ def get_lfp_data_raw(self): def align_trials( - self, label, event, units=None, data='spike_times', raw=False, + self, label, event, units=None, data='spike_event', raw=False, duration=1, sigma=None, dlc_project=None, video_match=None, end_event=None, ): @@ -1806,8 +1806,6 @@ def align_trials( data_options = [ 'behavioural', # Channels from behaviour TDMS file 'spike', # Raw/downsampled channels from probe (AP) - 'spike_times', # List of spike times per unit - 'spike_rate', # Spike rate signals from convolved spike times 'lfp', # Raw/downsampled channels from probe (LFP) 'motion_index', # Motion index per ROI from the video 'motion_tracking', # Motion tracking coordinates from DLC From 1c59585f409d677d34f2e7c58a7339e5ee6c7eb3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 11 Aug 2025 18:16:21 +0100 Subject: [PATCH 526/658] count cpu cores to set workers and multiprocessing context --- pixels/configs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pixels/configs.py b/pixels/configs.py index 17dec3b..f4fd2db 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -1,3 +1,5 @@ +import os + import logging from wavpack_numcodecs import WavPack @@ -15,11 +17,15 @@ #logging.warning('This is a warning message.') #logging.error('This is an error message.') +n_cores = os.cpu_count() # set si job_kwargs job_kwargs = dict( - n_jobs=0.8, # 80% core + mp_context="fork", # linux chunk_duration='1s', progress_bar=True, + #n_jobs=0.8, # 80% core + n_jobs=int(n_cores/4), # less worker + max_threads_per_worker=8, # but more thread each worker ) si.set_global_job_kwargs(**job_kwargs) From 3ff45f56d60f185308a73327212b6f729266d60b Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 11 Aug 2025 
18:24:18 +0100 Subject: [PATCH 527/658] change name so that it is more data type specific --- pixels/stream.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 2656084..153c9f5 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -82,7 +82,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): df, output from individual functions according to data type. """ - if "trial" in data: + if "spike_trial" in data: logging.info( f"\n> Aligning spike times and spike rate of {units} units to " f"<{label.name}> trials." @@ -90,7 +90,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): return self._get_aligned_trials( label, event, units=units, sigma=sigma, end_event=end_event, ) - elif "event" in data: + elif "spike_event" in data: logging.info( f"\n> Aligning spike times and spike rate of {units} units to " f"{event.name} event in <{label.name}> trials." @@ -678,7 +678,7 @@ def _bin_aligned_trials( # get aligned trials trials = self.align_trials( units=units, # NOTE: ALWAYS the first arg - data="trial_rate", # NOTE: ALWAYS the second arg + data="spike_trial", # NOTE: ALWAYS the second arg label=label, event=event, sigma=sigma, @@ -823,7 +823,7 @@ def get_positional_data( # get aligned firing rates and positions trials = self.align_trials( units=units, # NOTE: ALWAYS the first arg - data="trial_rate", # NOTE: ALWAYS the second arg + data="spike_trial", # NOTE: ALWAYS the second arg label=label, event=event, sigma=sigma, @@ -833,7 +833,7 @@ def get_positional_data( if normalised: grays = self.align_trials( units=units, # NOTE: ALWAYS the first arg - data="trial_rate", # NOTE: ALWAYS the second arg + data="spike_trial", # NOTE: ALWAYS the second arg label=getattr(label, label.name.split("_")[-1]), event=event.gray_on, sigma=sigma, @@ -1091,7 +1091,7 @@ def get_spike_chance(self, units, label, event, sigma, end_event): def _get_chance_args(self, units, label, event, sigma, end_event): trials = self.align_trials( units=units, # NOTE: ALWAYS the first arg - data="trial_rate", # NOTE: ALWAYS the second arg + data="spike_trial", # NOTE: ALWAYS the second arg label=label, event=event, sigma=sigma, From f56581416531a9ea56cdc0d947e9b4d1b2bc6846 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 11 Aug 2025 18:25:22 +0100 Subject: [PATCH 528/658] allows to extract unpreprocessed bands --- pixels/stream.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 153c9f5..6790da4 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -887,8 +887,12 @@ def preprocess_raw(self): return None - def extract_bands(self, freqs): - self.preprocess_raw() + def extract_bands(self, freqs, preprocess): + if preprocess: + self.preprocess_raw() + rec = self.files["preprocessed"] + else: + rec = self.load_raw_ap() if freqs == None: bands = freq_bands @@ -903,7 +907,7 @@ def extract_bands(self, freqs): ) # do bandpass filtering self.files[f"{name}_extracted"] = xut.extract_band( - self.files["preprocessed"], + rec, freq_min=freqs[0], freq_max=freqs[1], ) From 812a21476da8767e04d1d93e3f003799bc321408 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 11 Aug 2025 18:25:41 +0100 Subject: [PATCH 529/658] add todo --- pixels/stream.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 6790da4..6874bba 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -343,6 +343,13 @@ def _get_aligned_events(self, 
label, event, units=None, sigma=None): synched_vr.iloc[selected_starts].trial_count.unique() ) + # TODO aug 1 2025: + # lick happens more than once in a trial, thus i here does not + # correspond to trial index, fit it + # check if event happens more than once in each trial + if start_t.size > trial_ids.size: + trial_counts = synched_vr.loc[start_t, "trial_count"] + # map actual starting locations if not "trial_start" in event.name: all_start_idx = np.where( @@ -395,6 +402,7 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): trials_spiked = {} trials_positions = {} for i, start in enumerate(selected_starts): + assert 0 # select spike times of event in current trial trial_bool = (rec_spikes >= scan_starts[i])\ & (rec_spikes <= scan_ends[i]) From 5683886817e35ff6165c53db0ab96d19b66dc54b Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 11 Aug 2025 19:05:38 +0100 Subject: [PATCH 530/658] set workers to 4 cuz more than this it would not work --- pixels/configs.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pixels/configs.py b/pixels/configs.py index f4fd2db..4cdc436 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -1,5 +1,3 @@ -import os - import logging from wavpack_numcodecs import WavPack @@ -17,14 +15,12 @@ #logging.warning('This is a warning message.') #logging.error('This is an error message.') -n_cores = os.cpu_count() # set si job_kwargs job_kwargs = dict( mp_context="fork", # linux chunk_duration='1s', progress_bar=True, - #n_jobs=0.8, # 80% core - n_jobs=int(n_cores/4), # less worker + n_jobs=4, # less worker max_threads_per_worker=8, # but more thread each worker ) si.set_global_job_kwargs(**job_kwargs) From bf74cc607fe0ee89b175e5e68dc63be61cbf7716 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 12 Aug 2025 16:52:14 +0100 Subject: [PATCH 531/658] use spawn to avoid brokenprocesspool error --- pixels/configs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pixels/configs.py b/pixels/configs.py index 4cdc436..98999a2 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -17,11 +17,12 @@ # set si job_kwargs job_kwargs = dict( - mp_context="fork", # linux + #mp_context="fork", # linux, but does not work still on 2025 aug 12 + mp_context="spawn", # mac & win chunk_duration='1s', progress_bar=True, - n_jobs=4, # less worker - max_threads_per_worker=8, # but more thread each worker + n_jobs=0.8, + max_threads_per_worker=1, ) si.set_global_job_kwargs(**job_kwargs) From 31b5c794e7dfbea6125e626f0330102c6ebc931c Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 13 Aug 2025 14:20:40 +0100 Subject: [PATCH 532/658] use CAR not CMR cuz all bad channels are removed, and CAR preserves LFP better --- pixels/pixels_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4c742b7..9c86b1d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -113,13 +113,14 @@ def _preprocess_raw(rec, surface_depth): rec_clean = rec_removed.remove_channels(outside_chan_ids) print(f"\t\t> Removed {outside_chan_ids.size} outside channels.") - print("\t> step 3: do common median referencing.") - cmr = spre.common_reference( + print("\t> step 3: do common average referencing.") + car = spre.common_reference( rec_clean, + operator="average", # not median cuz all bad channels are removed dtype=np.int16, # make sure output is int16 ) - return cmr + return car def correct_ap_motion(rec, mc_method="dredge"): From 
bd27a36c681a7f162884df2b9e95f1f84ee62616 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 13 Aug 2025 14:26:21 +0100 Subject: [PATCH 533/658] remove pre-identified bad channels --- pixels/ioutils.py | 5 +++++ pixels/pixels_utils.py | 18 ++++++++++++++---- pixels/stream.py | 6 ++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index b152fc9..801cd4b 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -186,6 +186,11 @@ def get_data_files(data_dir, session_name): f"{mouse_id}_depth_info.yaml" ) + # identified faulty channels + pixels[stream_id]["faulty_channels"] = base_name.with_name( + f"{session_name}_{probe_id}_faulty_channels.yaml" + ) + #pixels[stream_id]["spike_rate_processed"] = base_name.with_name( # f"spike_rate_{stream_id}.h5" #) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9c86b1d..f959067 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -50,7 +50,7 @@ def load_raw(paths, stream_id): return rec -def preprocess_raw(rec, surface_depths): +def preprocess_raw(rec, surface_depths, faulty_channels): group_ids = rec.get_channel_groups() if np.unique(group_ids).size < 4: @@ -68,7 +68,7 @@ def preprocess_raw(rec, surface_depths): logging.info(f"\n> Preprocessing shank {g}") # get brain surface depth of shank surface_depth = surface_depths[g] - cleaned = _preprocess_raw(group, surface_depth) + cleaned = _preprocess_raw(group, surface_depth, faulty_channels[g]) preprocessed.append(cleaned) # aggregate groups together preprocessed = si.aggregate_channels(preprocessed) @@ -78,12 +78,16 @@ def preprocess_raw(rec, surface_depths): # get brain surface depth of shank surface_depth = surface_depths[unique_id] # preprocess - preprocessed = _preprocess_raw(rec, surface_depth) + preprocessed = _preprocess_raw( + rec, + surface_depth, + faulty_channels[unique_id], + ) return preprocessed -def _preprocess_raw(rec, surface_depth): +def _preprocess_raw(rec, surface_depth, faulty_channels): """ Implementation of preprocessing on raw pixels data. 
""" @@ -93,6 +97,12 @@ def _preprocess_raw(rec, surface_depth): # remove bad channels from sorting print("\t> step 2: remove bad channels.") + # remove pre-identified bad channels + chan_names = rec_ps.get_property("channel_name") + faulty_ids = rec_ps.channel_ids[np.isin(chan_names, faulty_channels)] + rec_ps = rec_ps.remove_channels(faulty_ids) + + # detect bad channels bad_chan_ids, chan_labels = spre.detect_bad_channels( rec_ps, outside_channels_location="top", diff --git a/pixels/stream.py b/pixels/stream.py index 6874bba..bdd36a7 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -886,10 +886,16 @@ def preprocess_raw(self): ) surface_depths = depth_info["raw_signal_depths"][self.stream_id] + # find faulty channels to remove + faulty_channels = file_utils.load_yaml( + path=self.processed / self.files["faulty_channels"], + ) + # preprocess self.files["preprocessed"] = xut.preprocess_raw( raw_rec, surface_depths, + faulty_channels, ) return None From 7e9857ade82c28a9517dc9a374674c7805fed0d3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:18:09 +0100 Subject: [PATCH 534/658] add typical oscillation bands --- pixels/constants.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/constants.py b/pixels/constants.py index f5aa59d..1ce2c99 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -8,6 +8,9 @@ freq_bands = { "ap":[300, 9000], "lfp":[0.5, 500], + "theta":[4, 11], # from Tom + "gamma":[30, 80], # from Tom + "ripple":[110, 220], # from Tom } BEHAVIOUR_HZ = 25000 From 9d4f6668871c406b0e0c757559cd28cee61b95b2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:18:53 +0100 Subject: [PATCH 535/658] add bnadwise psd --- pixels/ioutils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 801cd4b..564e078 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -137,6 +137,11 @@ def get_data_files(data_dir, session_name): f"{session_name}_{stream_id}_channel_clustering_results.h5" ) + # psd of theta and ripple to identify CA1 pyramidal layer + pixels[stream_id]["bandwise_psd"] = base_name.with_name( + f"{session_name}_{probe_id}_bandwise_psd.h5" + ) + # TODO mar 5 2025: # maybe do NOT put shuffled data in here, cuz there will be different # trial conditions, better to cache them??? 
From b7840305ac3680c6dad5d6bd8385c58a327d9501 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:19:20 +0100 Subject: [PATCH 536/658] make name more consistent --- pixels/pixels_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index f959067..fb0fe69 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -94,23 +94,23 @@ def _preprocess_raw(rec, surface_depth, faulty_channels): # correct phase shift print("\t> step 1: do phase shift correction.") rec_ps = spre.phase_shift(rec) - + # remove bad channels from sorting print("\t> step 2: remove bad channels.") # remove pre-identified bad channels chan_names = rec_ps.get_property("channel_name") faulty_ids = rec_ps.channel_ids[np.isin(chan_names, faulty_channels)] - rec_ps = rec_ps.remove_channels(faulty_ids) + rec_removed = rec_ps.remove_channels(faulty_ids) # detect bad channels bad_chan_ids, chan_labels = spre.detect_bad_channels( - rec_ps, + rec_removed, outside_channels_location="top", ) labels, counts = np.unique(chan_labels, return_counts=True) for label, count in zip(labels, counts): print(f"\t\t> Found {count} channels labelled as {label}.") - rec_removed = rec_ps.remove_channels(bad_chan_ids) + rec_removed = rec_removed.remove_channels(bad_chan_ids) # get channel group id and use it to index into brain surface channel depth shank_id = np.unique(rec_removed.get_channel_groups())[0] From 2ea59f3a9f2d655f72b73c178abde25cda16ff7a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:19:43 +0100 Subject: [PATCH 537/658] do NOT common average referencing in preprocessing, do it on demand --- pixels/pixels_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index fb0fe69..aadaaab 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -123,11 +123,9 @@ def _preprocess_raw(rec, surface_depth, faulty_channels): rec_clean = rec_removed.remove_channels(outside_chan_ids) print(f"\t\t> Removed {outside_chan_ids.size} outside channels.") - print("\t> step 3: do common average referencing.") - car = spre.common_reference( - rec_clean, - operator="average", # not median cuz all bad channels are removed - dtype=np.int16, # make sure output is int16 + return rec_clean + + ) return car From 1881be6dd0b9fcdcb9626160913f50fa506dcc48 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:20:01 +0100 Subject: [PATCH 538/658] separate car and cmr --- pixels/pixels_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index aadaaab..4e67946 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -126,10 +126,22 @@ def _preprocess_raw(rec, surface_depth, faulty_channels): return rec_clean +def CMR(rec, dtype=np.int16): + cmr = spre.common_reference( + rec, + operator="median", + dtype=dtype, ) + return cmr - return car +def CAR(rec, dtype=np.int16): + car = spre.common_reference( + rec, + operator="average", + dtype=np.int16, + ) + return car def correct_ap_motion(rec, mc_method="dredge"): """ From 58046e22758b18f1cf3b1e24694d683ba9df78c8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:20:16 +0100 Subject: [PATCH 539/658] notch a narrow band freq with si --- pixels/pixels_utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 
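Patches 537-538 pull referencing out of preprocessing and expose CAR and CMR as separate on-demand helpers around spre.common_reference. The reason CAR is acceptable here, per the earlier commit message, is that bad channels are removed first; a plain-numpy toy example of why that ordering matters.

    import numpy as np

    rng = np.random.default_rng(1)
    traces = rng.standard_normal((1000, 8))   # samples x channels
    traces[:, 0] += 50.0                      # one faulty channel with a big offset

    # common average vs common median reference, computed per sample
    car = traces - traces.mean(axis=1, keepdims=True)
    cmr = traces - np.median(traces, axis=1, keepdims=True)

    # the faulty channel drags the mean, leaking its offset into every channel;
    # the median barely moves
    print(round(car[:, 1].mean(), 2))   # about -6.25 (= -50 / 8)
    print(round(cmr[:, 1].mean(), 2))   # close to 0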
4e67946..75f5c88 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1507,3 +1507,29 @@ def save_chance_psd(sample_rate, positions, paths):#chance_data, idx, cols): ) return psds + + +def notch_freq(rec, freq, bw=4.0): + """ + Notch a frequency with narrow bandwidth. + + params + === + rec: si recording object. + + freq: float or int, the target frequency in Hz of the notch filter. + + bw: float or int, bandwidth (Hz) of notch filter. + Default: 4.0Hz. + + return + === + notched: spikeinterface recording object. + """ + notched = spre.notch_filter( + rec, + freq=freq, + q=freq/bw, # quality factor + ) + + return notched From 57979933d1f7968cac42fa28d2ae35305106c866 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:20:48 +0100 Subject: [PATCH 540/658] implement notch freq from scratch --- pixels/signal_utils.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/pixels/signal_utils.py b/pixels/signal_utils.py index d421638..abcd3a1 100644 --- a/pixels/signal_utils.py +++ b/pixels/signal_utils.py @@ -437,3 +437,38 @@ def motion_index(video, rois): mi = mi / mi.max(axis=0) return mi + + +def freq_notch(x, fs, w0, axis=0, bw=4.0): + """ + Use a notch filter that is a band-stop filter with a narrow bandwidth + between 48 to 52Hz. + It rejects a narrow frequency band and leaves the rest of the spectrum + little changed. + + params + === + x: array like, data to filter. + + fs: float or int, sampling frequency of x. + + w0: float or int, target frequency to notch (Hz). + + axis: int, axis of x to apply filter. + + bw: float or int, bandwidth of notch filter. + + return + === + notched: array like, notch filtered x. + """ + # convert to float + x = x.astype(np.float32, copy=False) + # set quality factor + Q = w0 / bw + # get numerator b & denominator a of IIR filter + b, a = scipy.signal.iirnotch(w0, Q, fs=fs) + # apply digital filter forward and backward + notched = scipy.signal.filtfilt(b, a, x, axis=axis) + + return notched From a394c58e8551655f3498a7bcc7eb4180ed8b3198 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:49:40 +0100 Subject: [PATCH 541/658] notch 50hz noise in lfp before car --- pixels/stream.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index bdd36a7..e26c98f 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -926,6 +926,16 @@ def extract_bands(self, freqs, preprocess): freq_max=freqs[1], ) + if "lfp" in name: + noise_freq = 50 # Hz + logging.info( + f"\n> Notching {noise_freq} Hz noise on {name} band." + ) + notched = xut.notch_freq( + rec=extracted, + freq=noise_freq, + ) + extracted = notched return None From ba24caffe6eb861c0573791bebfd22cdceeb1f68 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 15 Aug 2025 18:50:00 +0100 Subject: [PATCH 542/658] bandpass before car --- pixels/stream.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index e26c98f..33ef815 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -920,7 +920,7 @@ def extract_bands(self, freqs, preprocess): f"\n> Extracting {name} bands from {self.stream_id}." ) # do bandpass filtering - self.files[f"{name}_extracted"] = xut.extract_band( + extracted = xut.extract_band( rec, freq_min=freqs[0], freq_max=freqs[1], @@ -936,6 +936,12 @@ def extract_bands(self, freqs, preprocess): freq=noise_freq, ) extracted = notched + + logging.info( + f"\n> Common average referencing {name} band." 
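Patches 539-540 implement the same narrow-band notch two ways, once through spre.notch_filter and once from scratch with scipy's iirnotch plus filtfilt; in both, the quality factor is the target frequency divided by the bandwidth, so a 4 Hz band at 50 Hz gives Q = 12.5. A short scipy-only check of the resulting response; the 2500 Hz rate and the probe frequencies are illustrative.

    import numpy as np
    from scipy.signal import iirnotch, freqz

    fs = 2500.0
    w0, bw = 50.0, 4.0
    b, a = iirnotch(w0, Q=w0 / bw, fs=fs)

    # gain of the filter at a few frequencies around the notch
    probe_freqs = np.array([40.0, 48.0, 50.0, 52.0, 60.0])
    _, h = freqz(b, a, worN=probe_freqs, fs=fs)
    print(np.abs(h).round(3))   # ~1 away from 50 Hz, ~0 at 50 Hz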
+ ) + self.files[f"{name}_extracted"] = xut.CAR(extracted) + return None From bd617992b402d84d56a18fe54b458cba678b70b3 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 19 Aug 2025 16:17:53 +0100 Subject: [PATCH 543/658] do not notch 50hz by default cuz gamma is meaningful... --- pixels/stream.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 33ef815..5f335be 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -926,17 +926,6 @@ def extract_bands(self, freqs, preprocess): freq_max=freqs[1], ) - if "lfp" in name: - noise_freq = 50 # Hz - logging.info( - f"\n> Notching {noise_freq} Hz noise on {name} band." - ) - notched = xut.notch_freq( - rec=extracted, - freq=noise_freq, - ) - extracted = notched - logging.info( f"\n> Common average referencing {name} band." ) From 9865b225ca595ef0061661ab58f9531481714b2c Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 19 Aug 2025 16:18:17 +0100 Subject: [PATCH 544/658] do preprocessing by default --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 5f335be..407b0da 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -901,7 +901,7 @@ def preprocess_raw(self): return None - def extract_bands(self, freqs, preprocess): + def extract_bands(self, freqs, preprocess=True): if preprocess: self.preprocess_raw() rec = self.files["preprocessed"] From 36e3f9dffe9a3efa806021c4d5fbf7982230b738 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 19 Aug 2025 16:18:49 +0100 Subject: [PATCH 545/658] use logging --- pixels/experiment.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index bf5ca40..7b6809d 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -135,8 +135,10 @@ def extract_ap(self): def sort_spikes(self, mc_method="dredge"): """ Extract the spikes from raw spike data for all sessions. """ for i, session in enumerate(self.sessions): - print(">>>>> Sorting spikes for session {} ({} / {})" - .format(session.name, i + 1, len(self.sessions))) + logging.info( + "\n>>>>> Sorting spikes for session " + f"{session.name} ({i + 1} / {len(self.sessions)})" + ) session.sort_spikes(mc_method=mc_method) def assess_noise(self): From 21e1672831920f83252e1bd61c2a0e0dbf5718bf Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 19 Aug 2025 16:19:01 +0100 Subject: [PATCH 546/658] use ks mc by default --- pixels/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index 7b6809d..cddd7ce 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -132,7 +132,7 @@ def extract_ap(self): .format(session.name, i + 1, len(self.sessions))) session.extract_ap() - def sort_spikes(self, mc_method="dredge"): + def sort_spikes(self, mc_method="ks"): """ Extract the spikes from raw spike data for all sessions. 
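Patch 542 settles on bandpass-then-reference per band, and patch 543 then drops the default 50 Hz notch because the gamma band (30-80 Hz) straddles mains frequency, so notching would cut into a band of interest. A lazy-chain sketch of the per-band flow, assuming spikeinterface's toy generate_recording helper stands in for a real preprocessed recording.

    import spikeinterface.full as si
    import spikeinterface.preprocessing as spre

    freq_bands = {"lfp": [0.5, 500], "theta": [4, 11], "gamma": [30, 80]}

    rec = si.generate_recording(num_channels=4, durations=[1.0])

    extracted = {}
    for name, (lo, hi) in freq_bands.items():
        band = spre.bandpass_filter(rec, freq_min=lo, freq_max=hi)
        # reference after filtering so the average is taken on the band itself
        extracted[name] = spre.common_reference(band, operator="average")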
""" for i, session in enumerate(self.sessions): logging.info( From ff3ddec2d657beae3238e1ce4ded58339fecc954 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 2 Sep 2025 17:48:08 +0100 Subject: [PATCH 547/658] copy processed data to backup if backup exists --- pixels/behaviours/base.py | 10 ++++++++++ pixels/stream.py | 19 +++++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a1e9928..d7b3f52 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -90,6 +90,7 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, self.processed = self.data_dir / 'processed' / self.name else: self.processed = Path(processed_dir).expanduser() / self.name + self.backup = self.data_dir / 'processed' / self.name if hist_dir is None: self.histology = self.data_dir / 'histology'\ @@ -633,6 +634,11 @@ def detect_n_localise_peaks(self, loc_method="monopolar_triangulation"): # write to disk ioutils.write_hdf5(output, df) + if hasattr(self, "backup"): + # copy to backup if backup setup + copyfile(output, self.backup / output.name) + logging.info(f"\n> Detected peaks copied to {self.backup}.") + return None @@ -854,6 +860,10 @@ def sort_spikes(self, mc_method="ks"): output=output, sa_dir=sa_dir, ) + if hasattr(self, "backup"): + # copy to backup if backup setup + copyfile(output, self.backup / output.name) + logging.info(f"\n> Sorter ourput copied to {self.backup}.") return None diff --git a/pixels/stream.py b/pixels/stream.py index 407b0da..a4ce21e 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -618,9 +618,9 @@ def _sync_vr(self, vr_session): pixels_idx, ) + synched_vr_file = self.behaviour_files["vr_synched"][self.stream_num] file_utils.write_hdf5( - self.processed /\ - self.behaviour_files["vr_synched"][self.stream_num], + self.processed / synched_vr_file, synched_vr, ) @@ -640,6 +640,21 @@ def _sync_vr(self, vr_session): ) logging.info(f"\n> Action labels saved to: {action_labels_path}.") + if hasattr(self.session, "backup"): + # copy to backup if backup setup + copyfile( + self.processed / synched_vr_file, + self.session.backup / synched_vr_file, + ) + copyfile( + action_labels_path, + self.session.backup / action_labels_path.name, + ) + logging.info( + "\n> Syched vr and action labels copied to: " + f"{self.session.backup}." 
+ ) + return None From 95d1fab7cf88d6a800300a5666c0be96c6308355 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 2 Sep 2025 17:50:47 +0100 Subject: [PATCH 548/658] update spikeinterface in ks4 image --- pixels/configs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/configs.py b/pixels/configs.py index 98999a2..da48253 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -33,6 +33,7 @@ ) # kilosort 4 singularity image names -ks4_0_30_image_name = "si102.3_ks4-0-30_with_wavpack.sif" +ks4_0_30_image_name = "si103.0_ks4-0-30_with_wavpack.sif" +#ks4_0_30_image_name = "si102.3_ks4-0-30_with_wavpack.sif" ks4_0_18_image_name = "ks4-0-18_with_wavpack.sif" ks4_image_name = ks4_0_30_image_name From 2de9877011e37a3d355ac8bad8c60937abd1cf47 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 2 Sep 2025 17:51:26 +0100 Subject: [PATCH 549/658] add multiprocessing options --- pixels/configs.py | 10 ++++++---- pixels/stream.py | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pixels/configs.py b/pixels/configs.py index da48253..adaa936 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -17,12 +17,14 @@ # set si job_kwargs job_kwargs = dict( - #mp_context="fork", # linux, but does not work still on 2025 aug 12 - mp_context="spawn", # mac & win - chunk_duration='1s', + #pool_engine="thread", # instead of default "process" + pool_engine="process", + mp_context="fork", # linux, but does not work still on 2025 aug 12 + #mp_context="spawn", # mac & win progress_bar=True, n_jobs=0.8, - max_threads_per_worker=1, + chunk_duration='1s', + max_threads_per_worker=8, ) si.set_global_job_kwargs(**job_kwargs) diff --git a/pixels/stream.py b/pixels/stream.py index a4ce21e..0daa35b 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1,4 +1,5 @@ import gc +from shutil import copyfile import numpy as np import pandas as pd From c7a8ab9ed34a7aa23a634a514fa7837632d56e69 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 2 Sep 2025 17:53:56 +0100 Subject: [PATCH 550/658] add doc --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 0daa35b..11997fc 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -991,7 +991,7 @@ def sort_spikes(self, ks_mc, ks4_params, ks_image_path, output, sa_dir): return === - + None """ # use only preprocessed if use ks motion correction if ks_mc: From d59ba1f1ea511b688b370af48ba3846927c5754b Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 2 Sep 2025 17:54:48 +0100 Subject: [PATCH 551/658] copy to local by default otherwise too slow --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 11997fc..a4f8dcf 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -44,7 +44,7 @@ def __repr__(self): def load_raw_ap(self): - paths = [self.session.find_file(path, copy=False) for path in self.files["ap_raw"]] + paths = [self.session.find_file(path, copy=True) for path in self.files["ap_raw"]] self.files["si_rec"] = xut.load_raw(paths, self.stream_id) return self.files["si_rec"] From 7bb0a4a62088b2f643e5da736940c05795dd4f19 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 3 Sep 2025 11:55:51 +0100 Subject: [PATCH 552/658] use copytree to copy directory --- pixels/behaviours/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d7b3f52..dc4cd90 100644 --- a/pixels/behaviours/base.py +++ 
b/pixels/behaviours/base.py @@ -18,7 +18,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from pathlib import Path -from shutil import copyfile +from shutil import copyfile, copytree from typing import TYPE_CHECKING import matplotlib.pyplot as plt @@ -559,6 +559,12 @@ def correct_ap_motion(self, mc_method="dredge"): compressor=wv_compressor, ) + if hasattr(self, "backup"): + # copy to backup if backup setup + # zarr is a directory hence use copytree + copytree(output, self.backup / output.name) + logging.info(f"\n> Motion corrected copied to {self.backup}.") + return None @@ -862,7 +868,7 @@ def sort_spikes(self, mc_method="ks"): ) if hasattr(self, "backup"): # copy to backup if backup setup - copyfile(output, self.backup / output.name) + copytree(output, self.backup / output.name) logging.info(f"\n> Sorter ourput copied to {self.backup}.") return None From 064188e5a20410ddeef7076a74d58561d734dca9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 5 Sep 2025 13:18:51 +0100 Subject: [PATCH 553/658] use `dredge_ap` as default ap motion correction method --- pixels/behaviours/base.py | 2 +- pixels/experiment.py | 2 +- pixels/pixels_utils.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index dc4cd90..79a9265 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -511,7 +511,7 @@ def process_behaviour(self): logging.info("\n> Done!") - def correct_ap_motion(self, mc_method="dredge"): + def correct_ap_motion(self, mc_method="dredge_ap"): """ Correct motion of recording. diff --git a/pixels/experiment.py b/pixels/experiment.py index cddd7ce..d2f2757 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -132,7 +132,7 @@ def extract_ap(self): .format(session.name, i + 1, len(self.sessions))) session.extract_ap() - def sort_spikes(self, mc_method="ks"): + def sort_spikes(self, mc_method="dredge_ap"): """ Extract the spikes from raw spike data for all sessions. """ for i, session in enumerate(self.sessions): logging.info( diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 75f5c88..dd3529f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -144,13 +144,14 @@ def CAR(rec, dtype=np.int16): return car def correct_ap_motion(rec, mc_method="dredge"): +def correct_ap_motion(rec, mc_method="dredge_ap"): """ Correct motion of recording. params === mc_method: str, motion correction method. - Default: "dredge". + Default: "dredge_ap". (as of jan 2025, dredge performs better than ks motion correction.) "ks": let kilosort do motion correction. @@ -161,9 +162,8 @@ def correct_ap_motion(rec, mc_method="dredge"): logging.info(f"\n> Correcting motion with {mc_method}.") # reduce spatial window size for four-shank - # TODO may 8 2025 "method":"dredge_ap" after it's implemented? 
estimate_motion_kwargs = { - "method": "decentralized", + "method": "dredge_ap", "win_step_um": 100, "win_margin_um": -150, "verbose": True, From 14ac5f85e2deb648c505ffbcfeb1f6d476e2ae9e Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 5 Sep 2025 13:19:19 +0100 Subject: [PATCH 554/658] and set default motion correction method for lfp --- pixels/pixels_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index dd3529f..61fa6be 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -143,7 +143,11 @@ def CAR(rec, dtype=np.int16): ) return car -def correct_ap_motion(rec, mc_method="dredge"): + +def correct_lfp_motion(rec, mc_method="dredge_lfp"): + raise NotImplementedError("> Not implemented.") + + def correct_ap_motion(rec, mc_method="dredge_ap"): """ Correct motion of recording. From 1c9508137ee02e2d57aadca5421148e8d4bd2ae8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 5 Sep 2025 13:19:58 +0100 Subject: [PATCH 555/658] if backup exists find file in backup first --- pixels/behaviours/base.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 79a9265..cb4cd5d 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -221,7 +221,20 @@ def find_file(self, name: str, copy: bool=True) -> Optional[Path]: pathlib.Path : the full path to the desired file in the correct folder. """ + backup = self.backup / name processed = self.processed / name + if hasattr(self, "backup") and backup.exists()\ + and (not processed.exists()): + if copy: + logging.info( + f"\n {self.name}: Copying {name} to local processed" + ) + try: + copyfile(backup, processed) + except IsADirectoryError: + copytree(backup, processed) + return processed + if processed.exists(): return processed @@ -534,11 +547,13 @@ def correct_ap_motion(self, mc_method="dredge_ap"): streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - output = self.processed / stream_files["ap_motion_corrected"] - if output.exists(): + output = self.find_file(stream_files["ap_motion_corrected"]) + if output: logging.info(f"\n> {stream_id} already motion corrected.") stream_files["ap_motion_corrected"] = si.load(output) continue + else: + output = self.processed / stream_files["ap_motion_corrected"] logging.info( f"\n>>>>> Correcting motion for ap band from {stream_id} " @@ -837,6 +852,9 @@ def sort_spikes(self, mc_method="ks"): streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): # check if already sorted and exported + sorter_output = self.find_file( + stream_files["sorting_analyser"].parent + ) sa_dir = self.processed / stream_files["sorting_analyser"] if sa_dir.exists(): logging.info("\n> Already sorted and exported, next stream.") From 53f716871becf44f3f13d551bb1120fdcde61e2b Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 10 Sep 2025 10:47:53 +0100 Subject: [PATCH 556/658] make sure backup dir exists --- pixels/behaviours/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index cb4cd5d..20ae66f 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -89,8 +89,9 @@ def __init__(self, name, data_dir, metadata=None, processed_dir=None, if processed_dir is None: self.processed = self.data_dir / 'processed' / self.name else: - self.processed = 
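Patch 555 above makes find_file look in the slow backup location and cache any hit into the fast local processed directory before the usual search continues. A condensed standalone sketch of that lookup order; the function name and signature are illustrative, and the real method also searches interim and raw locations.

    from pathlib import Path
    from shutil import copyfile, copytree
    from typing import Optional

    def find_processed_file(name: str, processed: Path,
                            backup: Optional[Path]) -> Optional[Path]:
        local = processed / name
        if local.exists():
            return local
        if backup is not None and (backup / name).exists():
            src = backup / name
            local.parent.mkdir(parents=True, exist_ok=True)
            try:
                copyfile(src, local)    # plain file
            except IsADirectoryError:
                copytree(src, local)    # zarr stores and other directories
            return local
        return None    # caller falls back to the raw/interim search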
Path(processed_dir).expanduser() / self.name - self.backup = self.data_dir / 'processed' / self.name + self.processed = Path(processed_dir).expanduser() / self.name + self.backup = self.data_dir / 'processed' / self.name + self.backup.mkdir(parents=True, exist_ok=True) if hist_dir is None: self.histology = self.data_dir / 'histology'\ From 49a7f4daf0ae3c2c9809b9059920f01cd337a7a2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 10 Sep 2025 11:12:19 +0100 Subject: [PATCH 557/658] typo in x-lim of shank 2 --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 61fa6be..9df254c 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1188,7 +1188,7 @@ def correct_group_id(rec): shank_x_locs = { 0: [0, 32], 1: [250, 282], - 2: [500, 582], + 2: [500, 532], 3: [750, 782], } From 2ef0ed5d16b18a7dbea3cdb0a4422c67dc9df21f Mon Sep 17 00:00:00 2001 From: amac Date: Mon, 15 Sep 2025 16:48:37 +0100 Subject: [PATCH 558/658] use Worlds property to get dark worlds bool --- pixels/behaviours/virtual_reality.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 1c12dab..8954b6e 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -155,10 +155,9 @@ def _extract_action_labels(self, vr, vr_data): # define in gray in_gray = (vr_data.world_index == Worlds.GRAY) # define in dark - in_dark = vr_data.world_index.isin(Worlds.DARKS) - #in_dark = (vr_data.world_index == Worlds.DARK_5)\ - # | (vr_data.world_index == Worlds.DARK_2_5)\ - # | (vr_data.world_index == Worlds.DARK_FULL) + in_dark = vr_data.world_index.isin( + {w.value for w in Worlds if w.is_dark} + ) # define in white in_white = (vr_data.world_index == Worlds.WHITE) # define in tunnel @@ -489,7 +488,9 @@ def _get_world_masks(self, df: pd.DataFrame) -> WorldMasks: # define in gray in_gray = (df.world_index == Worlds.GRAY) # define in dark - in_dark = (df.world_index.isin(Worlds.DARKS)) + in_dark = df.world_index.isin( + {w.value for w in Worlds if w.is_dark} + ) # define in white in_white = (df.world_index == Worlds.WHITE) # define in light From 7d335c2f51d91d7583b151cb61f383150d4098c8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 17 Sep 2025 13:01:45 +0100 Subject: [PATCH 559/658] remove old version code --- pixels/behaviours/virtual_reality.py | 336 --------------------------- 1 file changed, 336 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 8954b6e..e44d2be 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -126,342 +126,6 @@ class Events(IntFlag): #run_stop = auto() -# map trial outcome -_outcome_map = { - Outcomes.ABORTED_DARK: "miss_dark", - Outcomes.ABORTED_LIGHT: "miss_light", - Outcomes.TRIGGERED: "triggered", - Outcomes.AUTO_LIGHT: "auto_light", - Outcomes.DEFAULT: "default_light", - Outcomes.REINF_LIGHT: "reinf_light", - Outcomes.AUTO_DARK: "auto_dark", - Outcomes.REINF_DARK: "reinf_dark", -} - -# function to look up trial type -trial_type_lookup = {v: k for k, v in vars(Conditions).items()} - -''' -class VR(Behaviour): - - def _extract_action_labels(self, vr, vr_data): - # NOTE: this func still called _extract_action_labels cuz it is - # inherited from motor data analysis, where each unit is an action. 
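Patch 558 replaces a hard-coded set of dark worlds with a per-member is_dark property, so the dark mask stays correct when new dark variants are added to the Worlds enum. The enum itself is not shown in this series; the members and values below are hypothetical, only the filtering pattern matches the patch.

    from enum import IntEnum
    import pandas as pd

    class Worlds(IntEnum):
        TUNNEL = 1
        GRAY = 2
        WHITE = 3
        DARK_5 = 5
        DARK_FULL = 6

        @property
        def is_dark(self):
            # any member whose name marks it as a dark world
            return self.name.startswith("DARK")

    world_index = pd.Series([1, 5, 2, 6, 3])
    in_dark = world_index.isin({w.value for w in Worlds if w.is_dark})
    print(in_dark.tolist())   # [False, True, False, True, False]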
- - events_arr = np.zeros(len(vr_data), dtype=np.uint32) - outcomes_arr = np.zeros(len(vr_data), dtype=np.uint32) - - # >>>> definitions >>>> - # define in gray - in_gray = (vr_data.world_index == Worlds.GRAY) - # define in dark - in_dark = vr_data.world_index.isin( - {w.value for w in Worlds if w.is_dark} - ) - # define in white - in_white = (vr_data.world_index == Worlds.WHITE) - # define in tunnel - in_tunnel = ~in_gray & ~in_white - # define in light - in_light = (vr_data.world_index == Worlds.TUNNEL) - # define light & dark trials - trial_light = (vr_data.trial_type == Conditions.LIGHT) - trial_dark = (vr_data.trial_type == Conditions.DARK) - # <<<< definitions <<<< - - logging.info("\n>> Mapping vr event times...") - - # >>>> gray >>>> - # get timestamps of gray - gray_idx = vr_data.world_index[in_gray].index - # get first grays - grays = np.where(gray_idx.diff() != 1)[0] - - # find time for first frame of gray - gray_on_t = gray_idx[grays] - # find their index in vr data - gray_on = vr_data.index.get_indexer(gray_on_t) - # bitwise_or.at will do: for each idx in gray_on: dst[idx] |= Events.gray_on - np.bitwise_or.at(events_arr, gray_on, Events.gray_on) - - # find time for last frame of gray - gray_off_t = np.append(gray_idx[grays[1:] - 1], gray_idx[-1]) - # find their index in vr data - gray_off = vr_data.index.get_indexer(gray_off_t) - np.bitwise_or.at(events_arr, gray_off, Events.gray_off) - # <<<< gray <<<< - - # >>>> punishment >>>> - # get timestamps of punishment - punish_idx = vr_data[in_white].index - # get first punishment - punishes = np.where(punish_idx.diff() != 1)[0] - - # find time for first frame of punishment - punish_on_t = punish_idx[punishes] - # find their index in vr data - punish_on = vr_data.index.get_indexer(punish_on_t) - np.bitwise_or.at(events_arr, punish_on, Events.punish_on) - - # find time for last frame of punish - punish_off_t = np.append(punish_idx[punishes[1:] - 1], punish_idx[-1]) - # find their index in vr data - punish_off = vr_data.index.get_indexer(punish_off_t) - np.bitwise_or.at(events_arr, punish_off, Events.punish_off) - # <<<< punishment <<<< - - # >>>> trial ends >>>> - # trial ends right before punishment starts - np.bitwise_or.at(events_arr, punish_on-1, Events.trial_end) - - # for non punished trials, right before gray on is when trial ends, and - # the last frame of the session - pre_gray_on_idx = np.append(gray_on[1:] - 1, vr_data.shape[0] - 1) - pre_gray_on = vr_data.iloc[pre_gray_on_idx] - # drop punish_off times - no_punished_t = pre_gray_on.drop(punish_off_t).index - # get index of trial ends in non punished trials - no_punished_idx = vr_data.index.get_indexer(no_punished_t) - np.bitwise_or.at(events_arr, no_punished_idx, Events.trial_end) - # <<<< trial ends <<<< - - # >>>> light >>>> - # get index of data in light tunnel - light_idx = vr_data[in_light].index - # get where light turns on - lights = np.where(light_idx.diff() != 1)[0] - # get timepoint of when light turns on - light_on_t = light_idx[lights] - # get index of when light turns on - light_on = vr_data.index.get_indexer(light_on_t) - np.bitwise_or.at(events_arr, light_on, Events.light_on) - - # get interval of possible starting position - start_interval = int(vr.meta_item('rand_start_int')) - - # find starting position in all light_on - trial_starts = light_on[np.where( - vr_data.iloc[light_on].position_in_tunnel % start_interval == 0 - )[0]] - # label trial starts - np.bitwise_or.at(events_arr, trial_starts, Events.trial_start) - - if not trial_starts.size == 
vr_data[in_tunnel].trial_count.max(): - raise PixelsError(f"Number of trials does not equal to " - "{vr_data.trial_count.max()}.") - # NOTE: if trial starts at 0, the first position_in_tunnel value will - # NOT be nan - - # last frame of light - light_off_t = np.append(light_idx[lights[1:] - 1], light_idx[-1]) - light_off = vr_data.index.get_indexer(light_off_t) - np.bitwise_or.at(events_arr, light_off, Events.light_off) - # <<<< light <<<< - - # NOTE: dark trials should in theory have EQUAL index pre_dark_end_idx - # and dark_on, BUT! after interpolation, some dark trials have their - # dark onset earlier than expected, those are corrected to the first - # expected position, they will have the same index as pre_dark_end_idx. - # others will not since their world_index change later than expected. - - # NOTE: if dark trial is aborted, light tunnel only turns off once; but - # if it is a reward is dispensed, light tunnel turns off twice - - # NOTE: number of dark_on does not match with number of dark trials - # caused by triggering punishment before dark - - # >>>> dark >>>> - # get index in dark - dark_idx = vr_data[in_dark].index - darks = np.where(dark_idx.diff() != 1)[0] - - # first frame of dark - dark_on_t = dark_idx[darks] - dark_on = vr_data.index.get_indexer(dark_on_t) - np.bitwise_or.at(events_arr, dark_on, Events.dark_on) - - # last frame of dark - dark_off_t = np.append(dark_idx[darks[1:] - 1], dark_idx[-1]) - dark_off = vr_data.index.get_indexer(dark_off_t) - np.bitwise_or.at(events_arr, dark_off, Events.dark_off) - # <<<< dark <<<< - - # >>>> licks >>>> - lick_onsets = np.diff(vr_data.lick_detect, prepend=0) - licked_idx = np.where(lick_onsets == 1)[0] - np.bitwise_or.at(events_arr, licked_idx, Events.licked) - # <<<< licks <<<< - - # TODO jun 27 2024 positional events and valve events needs mapping - # >>>> positional event mapping >>>> - - # >>>> Event: end of pre dark length >>>> - # NOTE: AL remove pre_dark_len + 10cm of his data - # get starting positions of all trials - start_pos = vr_data[in_tunnel].groupby( - "trial_count" - )["position_in_tunnel"].first() - - # plus pre dark - end_pre_dark = start_pos + vr.pre_dark_len - - def _first_post_pre_dark(df): - trial = df.name - pre_dark_end = end_pre_dark.loc[trial] - # mask and pick the first index - mask = df['position_in_tunnel'] >= pre_dark_end - if not mask.any(): - return None - return df.index[mask].min() - - pre_dark_end_t = vr_data[in_tunnel].groupby("trial_count").apply( - _first_post_pre_dark - ) - pre_dark_end_idx = vr_data.index.get_indexer( - pre_dark_end_t.dropna().astype(int) - ) - np.bitwise_or.at(events_arr, pre_dark_end_idx, Events.pre_dark_end) - # <<<< Event: end of pre dark length <<<< - - # >>>> Event: landmark 1 >>>> - # jun 3 2025: - # CONTINUE HERE! - # might as well implement the new refactored code before extend more - # landmarks... 
- - # <<<< Event: landmark 1 <<<< - - # >>>> Event: reward zone >>>> - # all indices in reward zone - in_zone = ( - vr_data.position_in_tunnel >= vr.reward_zone_start - ) & ( - vr_data.position_in_tunnel <= vr.reward_zone_end - ) - # first frame in reward zone - zone_on_t = ( - vr_data[in_zone] - .groupby("trial_count") - .apply(lambda g: g.index.min()) - ) - zone_on_idx = vr_data.index.get_indexer(zone_on_t) - np.bitwise_or.at(events_arr, zone_on_idx, Events.reward_zone_on) - - # last frame in reward zone - zone_off_t = ( - vr_data[in_zone] - .groupby("trial_count") - .apply(lambda g: g.index.max()) - ) - zone_off_idx = vr_data.index.get_indexer(zone_off_t) - np.bitwise_or.at(events_arr, zone_off_idx, Events.reward_zone_off) - # <<<< Event: reward zone <<<< - - logging.info("\n>> Mapping vr action times...") - - # get non-zero reward types - reward_not_none = (vr_data.reward_type != Outcomes.NONE) - - for t, trial in enumerate(vr_data.trial_count.unique()): - # get current trial - of_trial = (vr_data.trial_count == trial) - # get index of current trial - trial_idx = np.where(of_trial)[0] - # get start index of current trial - start_idx = trial_idx[np.isin(trial_idx, trial_starts)] - - # find where is non-zero reward type in current trial - reward_typed = vr_data[of_trial & reward_not_none] - # get trial type of current trial - trial_type = int(vr_data[of_trial].trial_type.unique()) - # get name of trial type in string - trial_type_str = trial_type_lookup.get(trial_type).lower() - - # >>>> map reward types >>>> - - # >>>> punished >>>> - if (reward_typed.size == 0)\ - & (vr_data[of_trial & in_white].size != 0): - # punished outcome - outcome = f"punished_{trial_type_str}" - outcomes_arr[trial_idx] = getattr(TrialTypes, outcome) - # or only mark the beginning of the trial? - #outcomes_arr[start_idx] = getattr(TrialTypes, outcome) - # <<<< punished <<<< - - elif (reward_typed.size == 0)\ - & (vr_data[of_trial & in_white].size == 0): - # >>>> unfinished trial >>>> - # double check it is the last trial - assert (trial == vr_data.trial_count.unique().max()) - assert (vr_data[of_trial].position_in_tunnel.max()\ - < vr.tunnel_reset) - logging.info(f"\n> trial {trial} is unfinished when session " - "ends, so there is no outcome.") - # <<<< unfinished trial <<<< - else: - # >>>> non punished >>>> - # get non-zero reward type in current trial - reward_type = int(reward_typed.reward_type.unique()) - # double check reward_type is in outcome map - assert (reward_type in _outcome_map) - - """ triggered """ - # catch triggered trials and separate trial types - if reward_type == Outcomes.TRIGGERED: - outcome = f"{_outcome_map[reward_type]}_{trial_type_str}" - else: - """ given & aborted """ - outcome = _outcome_map[reward_type] - # label outcome - outcomes_arr[trial_idx] = getattr(TrialTypes, outcome) - # or only mark the beginning of the trial? 
- #outcomes_arr[start_idx] = getattr(TrialTypes, outcome) - # <<<< non punished <<<< - - # >>>> non aborted, valve only >>>> - # if not aborted, map valve open & closed - if reward_type > Outcomes.NONE: - # map valve open - valve_open_idx = vr_data.index.get_indexer([reward_typed.index[0]]) - np.bitwise_or.at(events_arr, valve_open_idx, Events.valve_open) - # map valve closed - valve_closed_idx = vr_data.index.get_indexer( - [reward_typed.index[-1]] - ) - np.bitwise_or.at(events_arr, valve_closed_idx, Events.valve_closed) - # <<<< non aborted, valve only <<<< - - # <<<< map reward types <<<< - - # return typed arrays - return LabeledEvents( - outcome = outcomes_arr, - events = events_arr, - timestamps = vr_data.index.values, - ) - - - def _check_action_labels(self, vr_data, action_labels, plot=True): - - # TODO jun 9 2024 make this work, save the plot - if plot: - plt.clf() - _, axes = plt.subplots(4, 1, sharex=False, sharey=False) - axes[0].plot(back_sensor_signal) - if "/'Back_Sensor'/'0'" in behavioural_data: - axes[1].plot(behavioural_data["/'Back_Sensor'/'0'"].values) - else: - axes[1].plot(behavioural_data["/'ReachCue_LEDs'/'0'"].values) - axes[2].plot(action_labels[:, 0]) - axes[3].plot(action_labels[:, 1]) - plt.plot(action_labels[:, 1]) - plt.show() - - return action_labels -''' - - class LabeledEvents(NamedTuple): """Return type: timestamps + bitfields for outcome & events.""" timestamps: np.ndarray # shape (N,) From c8fbf24a6800db0754073374fc1e15cc679a61a0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 17 Sep 2025 13:02:12 +0100 Subject: [PATCH 560/658] first try find_file --- pixels/behaviours/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 20ae66f..fae839e 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1413,7 +1413,11 @@ def _get_processed_data(self, attr, key, category): if key in files: dirs = files[key] for f, file_dir in enumerate(dirs): - file_path = self.processed / file_dir + file_path = self.find_file(file_dir) + try: + assert (file_path is not None) + except: + file_path = self.processed / file_dir if file_path.exists(): if re.search(r'\.np[yz]$', file_path.suffix): saved[f] = np.load(file_path) From d5682c4b5e6e2eb3ade62c9838b9b96e016eaeac Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 18 Sep 2025 13:53:03 +0100 Subject: [PATCH 561/658] add label at pure luminance evoked activity post dark onset --- pixels/behaviours/virtual_reality.py | 10 ++++++++++ pixels/constants.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index e44d2be..8294162 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -19,6 +19,7 @@ from pixels import PixelsError from pixels.behaviours import Behaviour from pixels.configs import * +from pixels.constants import SAMPLE_RATE, V1_SPIKE_LATENCY class TrialTypes(IntFlag): @@ -125,6 +126,9 @@ class Events(IntFlag): #run_start = auto() #run_stop = auto() + # temporal events in dark only + dark_luminance_off = auto()# 500 ms + class LabeledEvents(NamedTuple): """Return type: timestamps + bitfields for outcome & events.""" @@ -406,6 +410,12 @@ def _first_post_mark(group_df, check_marks): masks[Events.pre_dark_end] = pre_dark_end_t # >>> distance travelled before dark onset per trial >>> + # >>> end of luminance change after dark onset >>> + n_frames = int(V1_SPIKE_LATENCY * SAMPLE_RATE / 1000) + 
dark_luminance_off_t = dark_on_t + n_frames + masks[Events.dark_luminance_off] = dark_luminance_off_t + # <<< end of luminance change after dark onset <<< + # >>> landmark 0 black wall >>> black_off = session.landmarks[0] diff --git a/pixels/constants.py b/pixels/constants.py index 1ce2c99..1d5cdf4 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -18,3 +18,9 @@ np.random.seed(BEHAVIOUR_HZ) REPEATS = 100 + +# latency of luminance change evoked spike +# NOTE: usually it should be between 40-60ms. now we use 500ms just to test +V1_SPIKE_LATENCY = 500 # ms #60 +V1_LFP_LATENCY = 40 # ms +LGN_SPIKE_LATENCY = 30 # ms From b6ce5f62697d73e99ffd24d7d06c8a5a9ebd666b Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 18 Sep 2025 13:54:21 +0100 Subject: [PATCH 562/658] add logging --- pixels/decorators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/decorators.py b/pixels/decorators.py index 17aa3cb..7e9a654 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -50,6 +50,7 @@ def wrapper(*args, **kwargs): logging.info(f"\n> Cache loaded from {cache_path}.") except HDF5ExtError: df = None + logging.info("\n> df is None, cache does not exist.") except (KeyError, ValueError): # if key="df" is not found, then use HDFStore to list and read # all dfs From 7aa3a318b21891f262356eeb95937ee930898800 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 18 Sep 2025 13:54:38 +0100 Subject: [PATCH 563/658] use double quote --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index a4f8dcf..b05f6bc 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -550,7 +550,7 @@ def sync_vr(self, vr_session): # get action labels & synched vr path action_labels = self.session.get_action_labels()[self.stream_num] synched_vr_path = self.session.find_file( - self.behaviour_files['vr_synched'][self.stream_num], + self.behaviour_files["vr_synched"][self.stream_num], ) if action_labels and synched_vr_path: logging.info(f"\n> {self.stream_id} from {self.session.name} is " From 805ddc7030283f73004144a643555f53bb61859d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 18 Sep 2025 13:54:55 +0100 Subject: [PATCH 564/658] only write if it does not exist --- pixels/stream.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index b05f6bc..0a2acbb 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -620,10 +620,13 @@ def _sync_vr(self, vr_session): ) synched_vr_file = self.behaviour_files["vr_synched"][self.stream_num] - file_utils.write_hdf5( - self.processed / synched_vr_file, - synched_vr, - ) + try: + assert self.session.find_file(synched_vr_file) + except: + file_utils.write_hdf5( + self.processed / synched_vr_file, + synched_vr, + ) # get action label dir action_labels_path = self.processed /\ From b948db50408a9c7a7ccef415c5c3a462eb06167e Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 18 Sep 2025 13:55:36 +0100 Subject: [PATCH 565/658] add todo --- pixels/stream.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 0a2acbb..ec7d438 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -404,6 +404,9 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): trials_positions = {} for i, start in enumerate(selected_starts): assert 0 + # TODO sep 16 2025: + # make sure it does not fail for event that happens more than once, + # e.g., licks # select spike times of event in current trial trial_bool = 
(rec_spikes >= scan_starts[i])\ & (rec_spikes <= scan_ends[i]) From 12d2eec1d50006992576cc948947a84e5dfc2495 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:09:15 +0100 Subject: [PATCH 566/658] allows to add stream_id as key --- pixels/units.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/pixels/units.py b/pixels/units.py index 366aee9..bcbb4d3 100644 --- a/pixels/units.py +++ b/pixels/units.py @@ -1,12 +1,33 @@ -class SelectedUnits(list): +from typing import Iterable + +class SelectedUnits(dict[str, list[int]]): name: str """ - This is the return type of Behaviour.select_units, which is a list in every way - except that when represented as a string, it can return a name, if a `name` - attribute has been set on it. This allows methods that have had `units` passed to be - cached to file. + A mapping from stream_id to lists of unit IDs. + Behaves like a dict in every way, except that when represented as a string, + it can return a `name` if set. This allows named instances to be cached to file. """ - def __repr__(self): + + def __init__(self, *args, name: str | None = None, **kwargs): + super().__init__(*args, **kwargs) + if name is not None: + self.name = name + + def __repr__(self) -> str: if hasattr(self, "name"): return self.name - return list.__repr__(self) + return dict.__repr__(self) + + # Convenience helpers + def add(self, stream_id: str, unit_id: int) -> None: + self.setdefault(stream_id, []).append(unit_id) + + def extend(self, stream_id: str, unit_ids: Iterable[int]) -> None: + self.setdefault(stream_id, []).extend(unit_ids) + + def flat(self) -> list[int]: + # If you sometimes need a flat list of all unit IDs + out: list[int] = [] + for ids in self.values(): + out.extend(ids) + return out From 4d308b7c50a35a50dddabaafc80ad0545bd0d5ef Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:11:46 +0100 Subject: [PATCH 567/658] use spikeinterface to get unit ids by default, and allows units from multiple streams --- pixels/behaviours/base.py | 88 ++++++--------------------------------- pixels/decorators.py | 2 + pixels/pixels_utils.py | 2 + pixels/stream.py | 3 ++ 4 files changed, 19 insertions(+), 76 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index fae839e..00cddcb 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1593,9 +1593,8 @@ def get_spike_times(self, units, remapped=False, use_si=False): def select_units( - self, group='good', min_depth=0, max_depth=None, min_spike_width=None, - unit_kwargs=None, max_spike_width=None, uncurated=False, name=None, - use_si=False, + self, min_depth=0, max_depth=None, min_spike_width=None, + unit_kwargs=None, max_spike_width=None, name=None, ): """ Select units based on specified criteria. The output of this can be passed to @@ -1603,10 +1602,6 @@ def select_units( Parameters ---------- - group : str, optional - The group to which the units that are wanted are part of. One of: 'group', - 'mua', 'noise' or None. Default is 'good'. - min_depth : int, optional (Only used when getting spike data). The minimum depth that units must be at to be included. Default is 0 i.e. in the brain. @@ -1623,9 +1618,6 @@ def select_units( (Only used when getting spike data). The maximum median spike width that units must have to be included. Default is None i.e. no maximum. - uncurated : bool, optional - Use uncurated units. Default: False. - name : str, optional Give this selection of units a name. 
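Patch 566 turns SelectedUnits into a stream-keyed mapping instead of a flat list, so unit selections from several probes can travel together while still caching under one name. A short usage sketch of the class as defined above; the stream ids and unit numbers are made up.

    from pixels.units import SelectedUnits

    units = SelectedUnits(name="deep_units")
    units.extend("imec0.ap", [12, 35, 77])
    units.add("imec1.ap", 4)

    print(units)               # "deep_units", so the selection can key a cache file
    print(units["imec0.ap"])   # [12, 35, 77]
    print(units.flat())        # [12, 35, 77, 4]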
This allows the list of units to be represented as a string, which enables caching. Future calls to cacheable @@ -1634,9 +1626,11 @@ def select_units( is the same between uses of the same name. """ - if use_si: - # NOTE: only deal with one stream for now - stream_files = self.files["pixels"]["imec0.ap"] + selected_units = SelectedUnits(name=name) if name is not None\ + else SelectedUnits() + + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser temp_sa = si.load(sa_dir) @@ -1659,14 +1653,9 @@ def select_units( # get units unit_ids = sa.unit_ids - # init units class - selected_units = SelectedUnits() - if name is not None: - selected_units.name = name - if name == "all": - selected_units.extend(unit_ids) - return selected_units + selected_units[stream_id] = unit_ids + continue # get shank id for units shank_ids = sa.sorting.get_property("group") @@ -1687,7 +1676,7 @@ def select_units( (shank_ids == shank_id) ] # add to list - selected_units.extend(in_range) + selected_units.extend(stream_id, in_range) else: # if there is only one shank # find units @@ -1695,62 +1684,9 @@ def select_units( (depths >= min_depth) & (depths < max_depth) ] # add to list - selected_units.extend(in_range) - - return selected_units - - else: - cluster_info = self.get_cluster_info() - selected_units = SelectedUnits() - if name is not None: - selected_units.name = name - - if min_depth is not None or max_depth is not None: - probe_depths = self.get_probe_depth() - - if min_spike_width == 0: - min_spike_width = None - if min_spike_width is not None or max_spike_width is not None: - widths = self.get_spike_widths() - else: - widths = None - - for stream_num, info in enumerate(cluster_info): - # TODO jun 12 2024 skip stream 1 for now - if stream_num > 0: - continue + selected_units.extend(stream_id, in_range) - id_key = 'id' if 'id' in info else 'cluster_id' - grouping = 'KSLabel' if uncurated else 'group' - - for unit in info[id_key]: - unit_info = info.loc[info[id_key] == unit].iloc[0].to_dict() - - # we only want units that are in the specified group - if not group or unit_info[grouping] == group: - - # and that are within the specified depth range - if min_depth is not None: - if probe_depths[stream_num] - unit_info['depth'] <= min_depth: - continue - if max_depth is not None: - if probe_depths[stream_num] - unit_info['depth'] > max_depth: - continue - - # and that have the specified median spike widths - if widths is not None: - width = widths[widths['unit'] == unit]['median_ms'] - assert len(width.values) == 1 - if min_spike_width is not None: - if width.values[0] < min_spike_width: - continue - if max_spike_width is not None: - if width.values[0] > max_spike_width: - continue - - selected_units.append(unit) - - return selected_units + return selected_units def _get_neuro_raw(self, kind): raw = [] diff --git a/pixels/decorators.py b/pixels/decorators.py index 7e9a654..90d80a7 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -1,3 +1,5 @@ +# annotations not evaluated at runtime +from __future__ import annotations import numpy as np import pandas as pd from tables import HDF5ExtError diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9df254c..b4304a1 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1,6 +1,8 @@ """ This module provides utilities for pixels data. 
""" +# annotations not evaluated at runtime +from __future__ import annotations import multiprocessing as mp import json diff --git a/pixels/stream.py b/pixels/stream.py index ec7d438..2d68740 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1,3 +1,6 @@ +# annotations not evaluated at runtime +from __future__ import annotations + import gc from shutil import copyfile From d8a24cd8146da1ed224d2c31c6b87ffd03ecf2fb Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:12:29 +0100 Subject: [PATCH 568/658] use label and event to get spike chance --- pixels/behaviours/base.py | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 00cddcb..2934732 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2576,6 +2576,53 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): return chance_psd + def get_spike_chance(self, label, event, sigma, end_event): + """ + Get trials aligned to an event. This finds all instances of label in the action + labels - these are the start times of the trials. Then this finds the first + instance of event on or after these start times of each trial. Then it cuts out + a period around each of these events covering all units, rearranges this data + into a MultiIndex DataFrame and returns it. + + Parameters + ---------- + label : int + An action label value to specify which trial types are desired. + + event : int + An event type value to specify which event to align the trials to. + + units : list of lists of ints, optional + The output from self.select_units, used to only apply this method to a + selection of units. + + end_event : int + For VR behaviour, when aligning to the whole trial, this param is + the end event to align to. + """ + units = self.select_units(name="all") + streams = self.files["pixels"] + output = {} + + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + output[stream_id] = stream.get_spike_chance( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + + assert 0 + return output + + def save_spike_chance(self, spiked, sigma, sample_rate): # TODO apr 21 2025: # do we put this func here or in stream.py??? 
From e9be668c2488d8fd4abae0786ae4546878b54421 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:13:17 +0100 Subject: [PATCH 569/658] write chance into zarr --- pixels/pixels_utils.py | 143 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index b4304a1..fa2e4ce 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -816,6 +816,149 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, return None +def _worker_write_repeat(i, zarr_path, sigma, sample_rate): + # Child process re-opens the store to avoid pickling big arrays + store = zarr.DirectoryStore(str(zarr_path)) + root = zarr.open_group(store=store, mode="a") + + # read spiked data + spiked = _read_df_from_zarr(root, "spiked") + + # get permuted data + c_spiked, c_fr = _permute_spikes_n_convolve_fr(spiked[:], sigma, sample_rate) + + # Write the i-th slice along last axis + root["chance_spiked"][..., i] = c_spiked + root["chance_fr"][..., i] = c_fr + + logging.info(f"\nRepeat {i} finished.") + + return None + + +def save_spike_chance_zarr( + zarr_path, + spiked: np.ndarray, + sigma: float, + sample_rate: float, + repeats: int = 100, + positions=None, + meta: dict | None = None, +): + """ + Create a Zarr store at `zarr_path` with datasets: + - spiked: base spiked array (read-only reference) + - chance_spiked: base_shape + (repeats,), int16 + - chance_fr: base_shape + (repeats,), float32 + - positions: optional small array (or vector), stored if provided + Then fill each repeat slice in parallel processes. + + This function is idempotent: if the target datasets exist and match shape, it skips creation. + """ + n_workers = 2 ** (mp.cpu_count().bit_length() - 2) + + zarr_path = Path(zarr_path) + zarr_path.parent.mkdir(parents=True, exist_ok=True) + + base_shape = spiked.shape + d_shape = base_shape + (repeats,) + + chunks = tuple(min(s, 1024) for s in base_shape) + (1,) + + store = zarr.DirectoryStore(str(zarr_path)) + root = zarr.group(store=store, overwrite=not zarr_path.exists()) + + # Metadata + root.attrs["kind"] = "spike_chance" + root.attrs["sigma"] = float(sigma) + root.attrs["sample_rate"] = float(sample_rate) + root.attrs["repeats"] = int(repeats) + if meta: + for k, v in meta.items(): + root.attrs[k] = v + + logging.info(f"\n> Creating zarr dataset.") + # Base source so workers can read it without pickling + if "spiked" in root: + del root["spiked"] + if isinstance(spiked, pd.DataFrame): + _write_df_as_zarr( + root, + spiked, + group_name="spiked", + compressor=compressor, + ) + else: + root.create_dataset( + "spiked", + data=spiked, + chunks=chunks[:-1], + compressor=compressor, + ) + + # Outputs + if "chance_spiked" in root\ + and tuple(root["chance_spiked"].shape) != d_shape: + del root["chance_spiked"] + if "chance_fr" in root and tuple(root["chance_fr"].shape) != d_shape: + del root["chance_fr"] + + if "chance_spiked" not in root: + root.create_dataset( + "chance_spiked", + shape=d_shape, + dtype="int16", + chunks=chunks, + compressor=compressor, + ) + if "chance_fr" not in root: + root.create_dataset( + "chance_fr", + shape=d_shape, + dtype="float32", + chunks=chunks, + compressor=compressor, + ) + + # save positions + if positions is not None: + if "positions" in root: + del root["positions"] + + if isinstance(positions, pd.DataFrame): + _write_df_as_zarr( + root, + positions, + group_name="positions", + compressor=compressor, + ) + else: + root.create_dataset( + "positions", + 
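The layout chosen in patch 569 is worth spelling out: the parent creates zarr datasets with a trailing repeats axis, chunked one repeat per chunk, and each worker re-opens the store by path and writes only its own slice, so no large array is pickled between processes and no two writers touch the same chunk. A stripped-down, self-contained version of that pattern, using the same zarr v2-style calls as the patch; the toy shapes and the plain permutation stand in for the real shuffle-and-convolve step.

    from concurrent.futures import ProcessPoolExecutor, as_completed
    from pathlib import Path
    import numpy as np
    import zarr

    def _worker(i, path):
        # the child re-opens the store instead of receiving arrays from the parent
        root = zarr.open_group(zarr.DirectoryStore(str(path)), mode="a")
        rng = np.random.default_rng(i)
        shuffled = rng.permutation(root["spiked"][:], axis=0)
        root["chance"][..., i] = shuffled    # distinct final-axis slice per repeat
        return i

    def main():
        path = Path("chance_demo.zarr")
        repeats = 4
        spiked = np.random.default_rng(0).integers(0, 2, size=(100, 8), dtype="int16")

        root = zarr.open_group(zarr.DirectoryStore(str(path)), mode="w")
        root.create_dataset("spiked", data=spiked)
        root.create_dataset(
            "chance",
            shape=spiked.shape + (repeats,),
            dtype="int16",
            chunks=spiked.shape + (1,),   # one chunk per repeat, so writers never collide
        )

        with ProcessPoolExecutor(max_workers=2) as ex:
            futures = [ex.submit(_worker, i, path) for i in range(repeats)]
            for f in as_completed(futures):
                f.result()

    if __name__ == "__main__":
        main()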
data=positions, + chunks=True, + compressor=compressor, + ) + + logging.info(f"\n> Starting process pool.") + # Parallel fill: each worker writes a distinct final-axis slice + with ProcessPoolExecutor(max_workers=n_workers) as ex: + futures = [ + ex.submit( + _worker_write_repeat, + i, + zarr_path, + sigma, + sample_rate, + ) for i in range(repeats) + ] + for f in as_completed(futures): + f.result() # raise on error + + # Done. The decorator will open and return the Zarr content. + return None + + def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, fr_df_path, sigma, sample_rate, repeats=100, spiked=None, spiked_shape=None, concat_spiked_path=None): From 2dee6709086d9ebe5c34df96809197996eefacd0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:14:05 +0100 Subject: [PATCH 570/658] import stuff up front --- pixels/pixels_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index fa2e4ce..5e5d9d7 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -4,7 +4,13 @@ # annotations not evaluated at runtime from __future__ import annotations import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor, as_completed import json +from pathlib import Path +import zarr + +import xarray as xr +from numcodecs import Blosc, VLenUTF8 import numpy as np import pandas as pd @@ -979,7 +985,6 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, sigma, sample_rate, """ Implementation of saving chance level spike data. """ - import concurrent.futures # save spiked to memmap if not yet # TODO apr 9 2025: if i have temp_spiked, how to get its shape? do i need @@ -1033,7 +1038,7 @@ def _save_spike_chance(spiked_memmap_path, fr_memmap_path, sigma, sample_rate, del chance_spiked, chance_fr # Set up the process pool to run the worker in parallel. - with concurrent.futures.ProcessPoolExecutor() as executor: + with ProcessPoolExecutor() as executor: # Submit jobs for each repeat. futures = [] for i in range(repeats): @@ -1627,7 +1632,6 @@ def save_chance_psd(sample_rate, positions, paths):#chance_data, idx, cols): Implementation of saving chance level spike data. """ #import concurrent.futures - from concurrent.futures import ProcessPoolExecutor, as_completed from vision_in_darkness.constants import PRE_DARK_LEN, landmarks # Set up the process pool to run the worker in parallel. 
From bb4dee278cc00f292740b236f09c1f920bc9ea07 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:14:16 +0100 Subject: [PATCH 571/658] write and read dataframe to/from zarr --- pixels/pixels_utils.py | 99 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 5e5d9d7..fd73de9 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1686,3 +1686,102 @@ def notch_freq(rec, freq, bw=4.0): ) return notched + + +def _write_df_as_zarr( + root, # zarr.hierarchy.Group + df: pd.DataFrame, + group_name: str = "positions", + *, + compressor=None, +): + # Remove any existing node (array or group) with this name + if group_name in root: + del root[group_name] + + ds = xr.Dataset.from_dataframe(df) + + # Find row dimension (the one matching len(df)) + row_dims = [d for d, n in ds.sizes.items() if n == len(df)] + row_dim = row_dims[0] if row_dims else "index" + + # Mark how to rebuild (Multi)Index on read + #ds.attrs["__via"] = "pandas_xarray_df" + #if isinstance(df.index, pd.MultiIndex): + # ds = ds.reset_index(row_dim) + # ds.attrs["__pd_mi_dim__"] = row_dim + # ds.attrs["__pd_mi_levels__"] = [ + # n if n is not None else f"level_{i}" for i, n in enumerate(df.index.names) + # ] + #else: + # ds.attrs["__pd_index_dim__"] = row_dim + if isinstance(df.columns, pd.MultiIndex): + # Encode as a 3D DataArray: (row_dim, *df.columns.names) + s = df.stack(df.columns.names) # Series with index (row_dim, level0, level1, ...) + da = xr.DataArray.from_series(s) # dims are names from the Series index + da = da.rename("values") + ds = da.to_dataset() + # Mark for round-trip + ds.attrs["__via"] = "pd_df_mi_cols" + ds.attrs["__row_dim__"] = row_dim + ds.attrs["__col_levels__"] = list(df.columns.names) + else: + # Single-level columns: still avoid per-column variables by using a 2D DataArray (row_dim, col_name) + col_dim = df.columns.name or "columns" + s = df.stack(col_dim) # Series with index (row_dim, col_dim) + da = xr.DataArray.from_series(s).rename("values") + ds = da.to_dataset() + ds.attrs["__via"] = "pd_df_cols" + ds.attrs["__row_dim__"] = row_dim + ds.attrs["__col_levels__"] = [col_dim] + + # chunking + chunks = 1024 + ds = ds.chunk({row_dim: chunks}) + + # Encoding: compressor and variable-length UTF-8 for object columns + #encoding = {v: {"compressor": compressor} for v in ds.data_vars} + #for v, da in ds.data_vars.items(): + # if da.dtype == object and VLenUTF8 is not None: + # encoding[v]["object_codec"] = VLenUTF8() + encoding = {"values": {"compressor": compressor}} + if ds["values"].dtype == object and VLenUTF8 is not None: + encoding["values"]["object_codec"] = VLenUTF8() + + # Write into a subgroup under the same store + ds.to_zarr( + store=root.store, + group=group_name, + mode="w", + encoding=encoding, + ) + + logging.info(f"\n> DataFrame {group_name} written to zarr.") + + return None + + +def _read_df_from_zarr(root, group_name: str) -> pd.DataFrame: + ds = xr.open_zarr( + store=root.store, + group=group_name, + consolidated=False, + chunks="auto", + ) + da = ds["values"] + row_dim = ds.attrs.get("__row_dim__", da.dims[0]) + col_levels = ds.attrs.get("__col_levels__", list(da.dims[1:])) + + # Series with MultiIndex index (row_dim, *col_levels) + s = da.to_series() + + # If there are column dims, unstack them back to columns + if col_levels: + df = s.unstack(col_levels) + else: + # No column dims -> a single column DataFrame + df = s.to_frame(name="values") + + df.index.name = row_dim + 
+ return df From 620c86b3933ac57a81b069e32f570d8c3ed89ac1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:14:36 +0100 Subject: [PATCH 572/658] get chance for experiment --- pixels/experiment.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pixels/experiment.py b/pixels/experiment.py index d2f2757..bfd6453 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -606,3 +606,12 @@ def sync_vr(self, vr): session.sync_vr(vr_session) return None + + + def get_spike_chance(self, *args, **kwargs): + chance = {} + for i, session in enumerate(self.sessions): + name = session.name + chance[name] = session.get_spike_chance(*args, **kwargs) + + return df From 5140fbcb5b5ab4577027911a470cc63fdf44f9cd Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:14:58 +0100 Subject: [PATCH 573/658] add zarr backend for cache --- pixels/decorators.py | 542 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 542 insertions(+) diff --git a/pixels/decorators.py b/pixels/decorators.py index 90d80a7..ea00177 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -1,5 +1,546 @@ # annotations not evaluated at runtime from __future__ import annotations + +import shutil +from pathlib import Path +from functools import wraps +from typing import Any + +import numpy as np +import pandas as pd +from tables import HDF5ExtError +try: + import zarr + from numcodecs import Blosc, VLenUTF8 +except Exception: + zarr = None + Blosc = None + VLenUTF8 = None + +try: + import xarray as xr +except Exception: + xr = None + +from pixels.configs import * +from pixels import ioutils +from pixels.error import PixelsError +from pixels.units import SelectedUnits + + +def _safe_key(s: str) -> str: + return str(s).replace("/", "_").replace(".", "_") + + +def _make_default_compressor() -> Any: + if Blosc is None: + return None + return Blosc(cname="zstd", clevel=5, shuffle=Blosc.BITSHUFFLE) + +# --------------------------- +# xarray <-> DataFrame helpers +# --------------------------- + +def _df_to_zarr_via_xarray( + df: pd.DataFrame, + *, + path: Path | None = None, + store: "zarr.storage.Store" | None = None, + group: str | None = None, + dim_name: str | None = None, + chunks: int | dict | None = None, + compressor=None, + mode: str = "w", +) -> None: + """ + Write DataFrame (supports MultiIndex) to a Zarr store/group via xarray. + If df.index is MultiIndex, we reset it into coordinate variables and record attrs + so we can reconstruct on read. + + Provide either path (DirectoryStore path) or (store, group). + """ + if xr is None or zarr is None: + raise ImportError( + "xarray/zarr not installed. 
pip install xarray zarr numcodecs" + ) + + ds = xr.Dataset.from_dataframe(df) + + # Find row dimension (the one matching len(df)) + row_dims = [d for d, n in ds.sizes.items() if n == len(df)] + row_dim = row_dims[0] if row_dims else "index" + + # Rename the row dimension if requested + if dim_name and row_dim != dim_name: + ds = ds.rename({row_dim: dim_name}) + row_dim = dim_name + + # Record how to reconstruct on read + ds.attrs["__via"] = "pandas_xarray_df" + if isinstance(df.index, pd.MultiIndex): + ds = ds.reset_index(row_dim) + ds.attrs["__pd_mi_dim__"] = row_dim + ds.attrs["__pd_mi_levels__"] = [ + n if n is not None else f"level_{i}"\ + for i, n in enumerate(df.index.names) + ] + else: + ds.attrs["__pd_index_dim__"] = row_dim + + # Chunking + if chunks is not None: + if isinstance(chunks, int): + ds = ds.chunk({row_dim: chunks}) + elif isinstance(chunks, dict): + ds = ds.chunk(chunks) + + # Compression and string handling + if compressor is None: + compressor = _make_default_compressor() + encoding = {var: {"compressor": compressor} for var in ds.data_vars} + for var, da in ds.data_vars.items(): + if da.dtype == object and VLenUTF8 is not None: + encoding[var]["object_codec"] = VLenUTF8() + + # Write + if path is not None: + ds.to_zarr(str(path), mode=mode, encoding=encoding) + try: + zarr.consolidate_metadata(str(path)) + except Exception: + pass + else: + assert store is not None + ds.to_zarr(store=store, group=group or "", mode=mode, encoding=encoding) + # consolidate requires a path; skipping here since we're inside a shared + # store + + +def _df_from_zarr_via_xarray( + *, + path: Path | None = None, + store: "zarr.storage.Store" | None = None, + group: str | None = None, +) -> pd.DataFrame: + """ + Read a DataFrame written by _df_to_zarr_via_xarray and reconstruct + MultiIndex if attrs exist. + Provide either path or (store, group). + """ + if xr is None or zarr is None: + raise ImportError( + "xarray/zarr not installed. pip install xarray zarr numcodecs" + ) + + if path is not None: + ds = xr.open_zarr(str(path), consolidated=True, chunks="auto") + else: + ds = xr.open_zarr( + store=store, + group=group or "", + consolidated=False, + chunks="auto", + ) + + # Reconstruct MI if recorded + mi_dim = ds.attrs.get("__pd_mi_dim__") + mi_levels = ds.attrs.get("__pd_mi_levels__") + if mi_dim and mi_levels: + ds = ds.set_index({mi_dim: mi_levels}) + + df = ds.to_dataframe() + return df + + +# ----------------------------------- +# Zarr read/write for arrays and dicts +# ----------------------------------- + +def _normalise_1d_chunks(chunks: Any, n: int) -> Any: + if chunks is None: + return None + if isinstance(chunks, int): + return (min(chunks, n),) + if isinstance(chunks, (tuple, list)): + if len(chunks) == 0: + return None + return (min(int(chunks[0]), n),) + return None + + +def _write_arrays_dicts_to_zarr( + root_path: Path, + obj: Any, + *, + chunks: Any = None, + compressor: Any = None, + overwrite: bool = False, +) -> None: + """ + Write ndarray or dict/nested-dict of ndarrays/DataFrames into a Zarr + directory. + DataFrames inside dicts are written via xarray into corresponding groups. + Top-level pure DataFrame should be written with _df_to_zarr_via_xarray + instead. + """ + if zarr is None: + raise ImportError( + "zarr/numcodecs not installed. 
pip install zarr numcodecs" + ) + + store = zarr.DirectoryStore(str(root_path)) + if overwrite and root_path.exists(): + shutil.rmtree(root_path) + root = zarr.group( + store=store, + overwrite=overwrite or (not root_path.exists()) + ) + + def write_into(prefix: str, value: Any): + # prefix: group path relative to root ("" for root) + if isinstance(value, np.ndarray): + g = zarr.open_group(store=store, path=prefix, mode="a") + name = "array" if prefix == "" else prefix.split("/")[-1] + # In groups, datasets live as siblings; for arrays we use the + # current group's name + # Better: use a fixed name for standalone arrays in a group + # Here we store as "values" for groups, or "array" at root if + # top-level ndarray + ds_name = "array" if prefix == "" else "values" + if ds_name in g: + del g[ds_name] + ds_chunks = chunks + if isinstance(chunks, int) and value.ndim == 1: + ds_chunks = _normalise_1d_chunks(chunks, len(value)) + g.create_dataset( + name=ds_name, + data=value, + chunks=ds_chunks, + compressor=compressor, + ) + + elif isinstance(value, pd.DataFrame): + # Write DF via xarray under this group path + _df_to_zarr_via_xarray( + value, + store=store, + group=prefix or "", + dim_name=value.index.name or "index", + chunks=chunks, + compressor=compressor, + mode="w", + ) + + elif isinstance(value, dict): + # Recurse for each item + for k, v in value.items(): + key = _safe_key(k) + next_prefix = f"{prefix}/{key}" if prefix else key + # Ensure group exists + zarr.open_group(store=store, path=prefix, mode="a") + write_into(next_prefix, v) + + else: + raise TypeError( + "Zarr backend supports ndarray, DataFrame, or dicts of them. " + f"Got: {type(value)} at group '{prefix or '/'}'" + ) + + if isinstance(obj, dict): + for k, v in obj.items(): + write_into(_safe_key(k), v) + elif isinstance(obj, np.ndarray): + write_into("", obj) + else: + raise TypeError( + "Top-level object must be ndarray or dict for this writer." + ) + + # Optional consolidation (best when writing via path) + try: + zarr.consolidate_metadata(str(root_path)) + except Exception: + pass + + +def _read_zarr_generic(root_path: Path) -> Any: + """ + Read back what _write_arrays_dicts_to_zarr wrote and also detect DataFrame + groups written via xarray (by checking group attrs['__via']). + + Returns: + - DataFrame (if top-level was DF written via xarray) + - zarr.Array (if top-level ndarray) + - dict tree mixing DataFrames and zarr.Arrays + """ + if zarr is None: + raise ImportError( + "zarr/numcodecs not installed. pip install zarr numcodecs" + ) + + store = zarr.DirectoryStore(str(root_path)) + if not root_path.exists(): + return None + root = zarr.open_group(store=store, mode="r") + + # If top-level was written via xarray as a DataFrame + if root.attrs.get("__via") == "pandas_xarray_df" and xr is not None: + return _df_from_zarr_via_xarray(store=store, group="") + + # If top-level is a single array written at root + if "array" in root and isinstance(root["array"], zarr.Array): + return root["array"] + + def read_from_group(prefix: str) -> Any: + g = zarr.open_group(store=store, path=prefix, mode="r") + # DataFrame group? 
+ if g.attrs.get("__via") == "pandas_xarray_df" and xr is not None: + return _df_from_zarr_via_xarray(store=store, group=prefix or "") + + out: dict[str, Any] = {} + for name, node in g.items(): + full = f"{prefix}/{name}" if prefix else name + if isinstance(node, zarr.Array): + # Arrays inside groups are stored as "values" + if name == "values" and prefix: + out[prefix.split("/")[-1]] = node + else: + out[name] = node + elif isinstance(node, zarr.hierarchy.Group): + # Recurse + res = read_from_group(full) + # If res is a DF and 'name' was only a container, store under + # that key + out[name] = res + return out + + return read_from_group("") + + +# ----------------------- +# Decorator with Zarr +# ----------------------- + +def cacheable( + _func=None, + *, + cache_format: str | None = None, + zarr_chunks: Any = None, + zarr_compressor: Any = None, + zarr_dim_name: str | None = None, +): + """ + Decorator factory for caching. + + Usage: + @cacheable # default HDF5 + def f(...): ... + + @cacheable(cache_format='zarr') # Zarr default for this method + def g(...): ... + + Backend precedence: + per-call kwarg cache_format > per-method (decorator args) + > instance default self._cache_format > 'hdf5' + + Zarr options: + - zarr_chunks: int (rows per chunk for DataFrame; 1D arrays) or tuple/dict for arrays/xarray + - zarr_compressor: a numcodecs compressor (e.g., Blosc(...)); default zstd+bitshuffle + - zarr_dim_name: optional row-dimension name when writing DataFrame via xarray + """ + default_backend = cache_format or "hdf5" + + def decorator(method): + @wraps(method) + def wrapper(*args, **kwargs): + name = kwargs.pop("name", None) + + # Per-call overrides + per_call_backend = kwargs.pop("cache_format", None) + per_call_chunks = kwargs.pop("zarr_chunks", zarr_chunks) + per_call_compressor = kwargs.pop("zarr_compressor", zarr_compressor) + per_call_dim_name = kwargs.pop("zarr_dim_name", zarr_dim_name) + + inst = args[0] + + # Units gating (unchanged) + if "units" in kwargs: + units = kwargs["units"] + if not isinstance(units, SelectedUnits) or not hasattr(units, "name"): + return method(*args, **kwargs) + + if not getattr(inst, "_use_cache", True): + return method(*args, **kwargs) + + # Build key parts (unchanged) + self_, *as_list = list(args) + list(kwargs.values()) + arrays = [i for i, arg in enumerate(as_list) if isinstance(arg, np.ndarray)] + if arrays: + if name is None: + raise PixelsError( + "Cacheing methods when passing arrays requires also " + "passing name='something'" + ) + for i in arrays: + as_list[i] = name + + key_parts = [method.__name__] + [ + str(i.name) if hasattr(i, "name") else str(i) for i in as_list + ] + base = inst.cache / ("_".join(key_parts) + f"_{inst.stream_id}") + + backend = per_call_backend\ + or getattr(inst, "_cache_format", None)\ + or default_backend + + # HDF5 backend + if backend == "hdf5": + cache_path = base.with_suffix(".h5") + if cache_path.exists() and inst._use_cache != "overwrite": + try: + df = ioutils.read_hdf5(cache_path) + logging.info(f"\n> Cache loaded from {cache_path}.") + except HDF5ExtError: + df = None + logging.info("\n> df is None, cache does not exist.") + except (KeyError, ValueError): + df = {} + with pd.HDFStore(cache_path, "r") as store: + for key in store.keys(): + parts = key.lstrip("/").split("/") + if len(parts) == 1: + df[parts[0]] = store[key] + elif len(parts) == 2: + stream, nm = parts[0], "/".join(parts[1:]) + df.setdefault(stream, {})[nm] = store[key] + logging.info(f"\n> Cache loaded from {cache_path}.") + else: + df = 
method(*args, **kwargs) + cache_path.parent.mkdir(parents=True, exist_ok=True) + if df is None: + cache_path.touch() + logging.info("\n> df is None, cache will exist but be empty.") + else: + if isinstance(df, dict): + if ioutils.is_nested_dict(df): + for probe_id, nested_dict in df.items(): + for nm, values in nested_dict.items(): + ioutils.write_hdf5( + path=cache_path, + df=values, + key=f"/{probe_id}/{nm}", + mode="a", + ) + else: + for nm, values in df.items(): + ioutils.write_hdf5( + path=cache_path, df=values, key=nm, mode="a" + ) + else: + ioutils.write_hdf5(cache_path, df) + return df + + # Zarr backend (with DataFrame via xarray, MultiIndex supported) + if backend == "zarr": + if zarr is None: + raise ImportError( + "cache_format='zarr' requires zarr. pip install zarr numcodecs xarray" + ) + zarr_path = base.with_suffix(".zarr") + can_read = zarr_path.exists() and inst._use_cache != "overwrite" + + if can_read: + try: + obj = _read_zarr_generic(zarr_path) + logging.info(f"\n> Zarr cache loaded from {zarr_path}.") + return obj + except Exception as e: + logging.info(f"\n> Failed to read Zarr cache ({e}); recomputing.") + + # inject reserved kwargs so the method can write directly to + # store + kwargs["_zarr_out"] = zarr_path + + # Compute fresh + result = method(*args, **kwargs) + if result is None: + # Method handled writing itself; read and return + obj = _read_zarr_generic(zarr_path) + logging.info(f"\n> Zarr cache written to {zarr_path}.") + return obj + + # Overwrite + if inst._use_cache == "overwrite" and zarr_path.exists(): + shutil.rmtree(zarr_path) + + compressor = per_call_compressor or _make_default_compressor() + + # DataFrame via xarray (works for MultiIndex) + if isinstance(result, pd.DataFrame): + _df_to_zarr_via_xarray( + result, + path=zarr_path, + dim_name=per_call_dim_name or result.index.name or "index", + chunks=per_call_chunks, + compressor=compressor, + mode="w", + ) + logging.info( + f"\n> Zarr cache (DataFrame via xarray) written to {zarr_path}." + ) + return result + + # Dict/nested-dict of arrays or DataFrames + if isinstance(result, dict) or isinstance(result, np.ndarray): + _write_arrays_dicts_to_zarr( + zarr_path, + result, + chunks=per_call_chunks, + compressor=compressor, + overwrite=True, + ) + logging.info(f"\n> Zarr cache written to {zarr_path}.") + return result + + # Fallback for unsupported types: write HDF5 like before + logging.warning( + "cache_format='zarr' requested but result type " + "not supported for Zarr; falling back to HDF5." 
+ ) + h5_fallback = base.with_suffix(".h5") + if isinstance(result, dict): + if ioutils.is_nested_dict(result): + for probe_id, nested_dict in result.items(): + for nm, values in nested_dict.items(): + ioutils.write_hdf5( + path=h5_fallback, + df=values, + key=f"/{probe_id}/{nm}", + mode="a", + ) + else: + for nm, values in result.items(): + ioutils.write_hdf5( + path=h5_fallback, + df=values, + key=str(nm), + mode="a", + ) + else: + ioutils.write_hdf5(h5_fallback, result) + logging.info(f"\n> Cache written to {h5_fallback} (fallback).") + return result + + raise ValueError(f"Unknown cache_format/backend: {backend}") + + return wrapper + + if _func is None: + return decorator + else: + return decorator(_func) + +''' import numpy as np import pandas as pd from tables import HDF5ExtError @@ -103,3 +644,4 @@ def wrapper(*args, **kwargs): ioutils.write_hdf5(cache_path, df) return df return wrapper +''' From 309c9312dd529ad0015e471c86112a698789203c Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:15:19 +0100 Subject: [PATCH 574/658] add default compressor for generic zarr --- pixels/configs.py | 8 ++++++++ pixels/pixels_utils.py | 1 + 2 files changed, 9 insertions(+) diff --git a/pixels/configs.py b/pixels/configs.py index adaa936..f0006c8 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -1,6 +1,7 @@ import logging from wavpack_numcodecs import WavPack +from numcodecs import Blosc import spikeinterface as si # Configure logging to include a timestamp with seconds @@ -34,6 +35,13 @@ bps=None, # lossless ) +# use blosc compressor for generic zarr +compressor = Blosc( + cname="zstd", + clevel=5, + shuffle=Blosc.BITSHUFFLE, +) + # kilosort 4 singularity image names ks4_0_30_image_name = "si103.0_ks4-0-30_with_wavpack.sif" #ks4_0_30_image_name = "si102.3_ks4-0-30_with_wavpack.sif" diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index fd73de9..969ab2f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -3,6 +3,7 @@ """ # annotations not evaluated at runtime from __future__ import annotations + import multiprocessing as mp from concurrent.futures import ProcessPoolExecutor, as_completed import json From 6d683ed5b5753a9cf001395b711e76597fb0fae8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:17:15 +0100 Subject: [PATCH 575/658] break down align trials parts to reuse code --- pixels/stream.py | 322 ++++++++++++++++++++++++++++------------------- 1 file changed, 193 insertions(+), 129 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 2d68740..91b6d12 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -53,73 +53,10 @@ def load_raw_ap(self): return self.files["si_rec"] - @cacheable - def align_trials(self, units, data, label, event, sigma, end_event): - """ - Align pixels data to behaviour trials. - - params - === - units : list of lists of ints, optional - The output from self.select_units, used to only apply this method to a - selection of units. - - data : str, optional - The data type to align. - - label : int - An action label value to specify which trial types are desired. - - event : int - An event type value to specify which event to align the trials to. - - sigma : int, optional - Time in milliseconds of sigma of gaussian kernel to use when - aligning firing rates. - - end_event : int | None - For VR behaviour, when aligning to the whole trial, this param is - the end event to align to. - - return - === - df, output from individual functions according to data type. 
- """ - - if "spike_trial" in data: - logging.info( - f"\n> Aligning spike times and spike rate of {units} units to " - f"<{label.name}> trials." - ) - return self._get_aligned_trials( - label, event, units=units, sigma=sigma, end_event=end_event, - ) - elif "spike_event" in data: - logging.info( - f"\n> Aligning spike times and spike rate of {units} units to " - f"{event.name} event in <{label.name}> trials." - ) - return self._get_aligned_events( - label, event, units=units, sigma=sigma, - ) - else: - raise NotImplementedError( - "> Other types of alignment are not implemented." - ) - - - def _get_aligned_trials( - self, label, event, units=None, sigma=None, end_event=None, - ): + def _map_trials(self, label, event, end_event=None): # get synched pixels stream with vr and action labels synched_vr, action_labels = self.get_synched_vr() - # get positions of all trials - all_pos = synched_vr.position_in_tunnel - - # get spike times - spikes = self.get_spike_times(units) - # get action and event label file outcomes = action_labels["outcome"] events = action_labels["events"] @@ -141,30 +78,50 @@ def _get_aligned_trials( selected_ends = trials[np.where(np.isin(trials, ends))[0]] end_t = timestamps[selected_ends] + # use original trial id as trial index + trial_ids = pd.Index( + synched_vr.iloc[selected_starts].trial_count.unique() + ) + + return trials, events, selected_starts, start_t, end_t, trial_ids + + + #@cacheable(cache_format="zarr") + def _get_vr_positions(self, label, event, end_event): + logging.info( + f"\n> Getting {self.session.name} {self.stream_id} positions." + ) + + # map trials + (trials, events, selected_starts, + start_t, end_t, trial_ids) = self._map_trials( + label, + event, + end_event, + ) + if selected_starts.size == 0: logging.info(f"\n> No trials found with label {label} and event " f"{event.name}, output will be empty.") return None - # use original trial id as trial index - trial_ids = pd.Index( - synched_vr.iloc[selected_starts].trial_count.unique() - ) + # get synched vr + synched_vr, _ = self.get_synched_vr() + + # get positions of all trials + all_pos = synched_vr.position_in_tunnel + all_pos_val = all_pos.to_numpy() + all_pos_idx = all_pos.index.to_numpy() # map actual starting locations if not "trial_start" in event.name: - all_start_idx = np.where( - np.bitwise_and(events, event.trial_start) - )[0] - start_idx = trials[np.where( - np.isin(trials, all_start_idx) - )[0]] + all_start_idx = np.flatnonzero(events & event.trial_start) + start_idx = trials[np.isin(trials, all_start_idx)] else: start_idx = selected_starts.copy() - start_pos = synched_vr.position_in_tunnel.iloc[ - start_idx - ].values.astype(int) + # get start positions + start_pos = all_pos_val[start_idx].astype(int) # create multiindex with starts cols_with_starts = pd.MultiIndex.from_arrays( @@ -172,6 +129,45 @@ def _get_aligned_trials( names=("start", "trial"), ) + # find start and end position index + trial_start_t = np.searchsorted(all_pos_idx, start_t, side="left") + trial_end_t = np.searchsorted(all_pos_idx, end_t, side="right") + + trials_positions = [ + pd.Series(all_pos_val[s:e]) + for s, e in zip(trial_start_t, trial_end_t) + ] + positions = pd.concat(trials_positions, axis=1) + positions.columns = cols_with_starts + positions = positions.sort_index(axis=1, ascending=[False, True]) + positions.index.name = "time" + + return positions + + + #@cacheable(cache_format="zarr") + def _get_vr_spikes(self, units, label, event, sigma, end_event): + logging.info( + f"\n> Getting 
{self.session.name} {self.stream_id} spikes." + ) + + # map trials timestamps and index + (_, _, selected_starts, + start_t, end_t, trial_ids) = self._map_trials( + label, + event, + end_event, + ) + + if selected_starts.size == 0: + logging.info(f"\n> No trials found with label {label} and event " + f"{event.name}, output will be empty.") + return None + + # get spike times + spikes = self.get_spike_times(units) + units = units[self.stream_id] + # pad ends with 1 second extra to remove edge effects from # convolution scan_pad = self.BEHAVIOUR_SAMPLE_RATE @@ -193,47 +189,38 @@ def _get_aligned_trials( ] - cursor_duration cursor += samples - output = {} - trials_fr = {} trials_spiked = {} - trials_positions = {} + trials_fr = {} for i, start in enumerate(selected_starts): # select spike times of current trial trial_bool = (rec_spikes >= scan_starts[i])\ & (rec_spikes <= scan_ends[i]) trial = rec_spikes[trial_bool] - # get position bin ids for current trial - trial_pos_bool = (all_pos.index >= start_t[i])\ - & (all_pos.index <= end_t[i]) - trial_pos = all_pos[trial_pos_bool] # initiate binary spike times array for current trial # NOTE: dtype must be float otherwise would get all 0 when passing # gaussian kernel - times = np.zeros((scan_durations[i], len(units))).astype(float) + times = np.zeros((scan_durations[i], len(units)), dtype=float) # use pixels time as spike index idx = np.arange(scan_starts[i], scan_ends[i]) - # make it df, column name being unit id - spiked = pd.DataFrame(times, index=idx, columns=units) - # TODO mar 5 2025: how to separate aligned trial times and chance, - # so that i can use cacheable to get all conditions?????? - for unit in trial: + for j, unit in enumerate(trial): # get spike time for unit u_times = trial[unit].values # drop nan u_times = u_times[~np.isnan(u_times)] # round spike times to use it as index - u_spike_idx = np.round(u_times).astype(int) + u_spike_idx = np.round(u_times).astype(int) - scan_starts[i] # make sure it does not exceed scan duration - if (u_spike_idx >= scan_ends[i]).any(): - beyonds = np.where(u_spike_idx >= scan_ends[i])[0] - u_spike_idx[beyonds] = idx[-1] - # make sure no double counted - u_spike_idx = np.unique(u_spike_idx) + u_spike_idx = u_spike_idx[ + (u_spike_idx >= 0) & (u_spike_idx < scan_durations[i]) + ] + if u_spike_idx.size: + # set spiked to 1 + times[np.unique(u_spike_idx), j] = 1 - # set spiked to 1 - spiked.loc[u_spike_idx, unit] = 1 + # make it df, column name being unit id + spiked = pd.DataFrame(times, index=idx, columns=units) # convolve spike trains into spike rates rates = signal.convolve_spike_trains( @@ -251,18 +238,6 @@ def _get_aligned_trials( trials_fr[trial_ids[i]] = rates spiked.reset_index(inplace=True, drop=True) trials_spiked[trial_ids[i]] = spiked - trial_pos.reset_index(inplace=True, drop=True) - trials_positions[trial_ids[i]] = trial_pos - - # concat trial positions - positions = ioutils.reindex_by_longest( - dfs=trials_positions, - idx_names=["trial", "time"], - level="trial", - return_format="dataframe", - ) - positions.columns = cols_with_starts - positions = positions.sort_index(axis=1, ascending=[False, True]) # get trials vertically stacked spiked stacked_spiked = pd.concat( @@ -271,25 +246,8 @@ def _get_aligned_trials( ) stacked_spiked.index.names = ["trial", "time"] stacked_spiked.columns.names = ["unit"] - - # TODO apr 21 2025: - # save spike chance only if all units are selected, else - # only index into the big chance array and save into zarr - #if units.name == "all" and (label == 725 or 
1322): - # self.save_spike_chance( - # stream_files=stream_files, - # spiked=stacked_spiked, - # sigma=sigma, - # ) - #else: - # # access chance data if we only need part of the units - # self.get_spike_chance( - # sample_rate=self.SAMPLE_RATE, - # positions=all_pos, - # sigma=sigma, - # ) - # assert 0 - + + output = {} # get trials horizontally stacked spiked spiked = ioutils.reindex_by_longest( dfs=stacked_spiked, @@ -306,8 +264,114 @@ def _get_aligned_trials( output["spiked"] = spiked output["fr"] = fr + + return stacked_spiked, output + + + @cacheable#(cache_format="zarr") + def align_trials(self, units, data, label, event, sigma, end_event): + """ + Align pixels data to behaviour trials. + + params + === + units : dictionary of lists of ints + The output from self.select_units, used to only apply this method to a + selection of units. + + data : str, optional + The data type to align. + + label : int + An action label value to specify which trial types are desired. + + event : int + An event type value to specify which event to align the trials to. + + sigma : int, optional + Time in milliseconds of sigma of gaussian kernel to use when + aligning firing rates. + + end_event : int | None + For VR behaviour, when aligning to the whole trial, this param is + the end event to align to. + + return + === + df, output from individual functions according to data type. + """ + + if "spike_trial" in data: + logging.info( + f"\n> Aligning spike times and spike rate of {units} units to " + f"<{label.name}> trials." + ) + return self._get_aligned_trials( + label, event, units=units, sigma=sigma, end_event=end_event, + ) + elif "spike_event" in data: + logging.info( + f"\n> Aligning spike times and spike rate of {units} units to " + f"{event.name} event in <{label.name}> trials." + ) + return self._get_aligned_events( + label, event, units=units, sigma=sigma, + ) + else: + raise NotImplementedError( + "> Other types of alignment are not implemented." 
+ ) + + + def _get_aligned_trials( + self, label, event, units=None, sigma=None, end_event=None, + ): + # get positions + positions = self._get_vr_positions(label, event, end_event) + + # get spikes and firing rate + _, output = self._get_vr_spikes( + units, + label, + event, + sigma, + end_event, + ) output["positions"] = positions + if units.name == "all" and (label.name == "light" or "dark"): + self.get_spike_chance( + #stream_files=stream_files, + spiked=stacked_spiked, + sigma=sigma, + ) + #else: + # # access chance data if we only need part of the units + # self.get_spike_chance( + # sample_rate=self.SAMPLE_RATE, + # positions=all_pos, + # sigma=sigma, + # ) + # assert 0 + + # get trials horizontally stacked spiked + #spiked = ioutils.reindex_by_longest( + # dfs=stacked_spiked, + # level="trial", + # return_format="dataframe", + #) + #fr = ioutils.reindex_by_longest( + # dfs=trials_fr, + # level="trial", + # idx_names=["trial", "time"], + # col_names=["unit"], + # return_format="dataframe", + #) + + #output["spiked"] = spiked + #output["fr"] = fr + #output["positions"] = positions + return output From 250391ef06678ecc0e839f314433f6de7a58829e Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:17:40 +0100 Subject: [PATCH 576/658] now need stream_id to access units --- pixels/stream.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 91b6d12..1ea2339 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -384,6 +384,8 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): # get spike times spikes = self.get_spike_times(units) + # now get array unit ids + units = units[self.stream_id] # get action and event label file outcomes = action_labels["outcome"] @@ -584,6 +586,8 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): def get_spike_times(self, units): + units = units[self.stream_id] + # find sorting analyser path sa_path = self.session.find_file(self.files["sorting_analyser"]) # load sorting analyser From b6475231b9c8f4bf84b45579596d0a30377c4bba Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 19 Sep 2025 20:18:43 +0100 Subject: [PATCH 577/658] cache spike chance into zarr --- pixels/stream.py | 45 +++++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 1ea2339..b4d9ba4 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1186,8 +1186,16 @@ def get_spatial_psd( return psd_df - def get_spike_chance(self, units, label, event, sigma, end_event): - positions, paths = self._get_chance_args( + @cacheable(cache_format="zarr") + def get_spike_chance(self, units, label, event, sigma, end_event, + # reserved kwargs injected by decorator when cache_format='zarr' + _zarr_out: Path | str | None = None, + ): + # get positions + positions = self._get_vr_positions(label, event, end_event) + + # get spikes + stacked_spikes, _ = self._get_vr_spikes( units, label, event, @@ -1195,26 +1203,31 @@ def get_spike_chance(self, units, label, event, sigma, end_event): end_event, ) - fr_chance, idx, cols = xut.get_spike_chance( + logging.info( + f"\n> Getting {self.session.name} {self.stream_id} spike chance." 
+ ) + + xut.save_spike_chance_zarr( + zarr_path=_zarr_out, + spiked=stacked_spikes, + sigma=sigma, sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + repeats=2, #REPEATS, positions=positions, - **paths, + meta=dict( + label=str(label), + event=str(event), + end_event=str(end_event), + units_name=getattr(units, "name", None), + ), ) + del stacked_spikes, positions + gc.collect() - return positions, fr_chance, idx, cols - + return None - def _get_chance_args(self, units, label, event, sigma, end_event): - trials = self.align_trials( - units=units, # NOTE: ALWAYS the first arg - data="spike_trial", # NOTE: ALWAYS the second arg - label=label, - event=event, - sigma=sigma, - end_event=end_event, # NOTE: ALWAYS the last arg - ) - positions = trials["positions"] + def _get_chance_args(self, label, event, sigma, end_event): probe_id = self.stream_id[:-3] name = self.session.name paths = { From 9baa549de89e755c9989cdb511b69d55f84dd02a Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:46:12 +0100 Subject: [PATCH 578/658] use bool to reduce size --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 969ab2f..1efad34 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -914,7 +914,7 @@ def save_spike_chance_zarr( root.create_dataset( "chance_spiked", shape=d_shape, - dtype="int16", + dtype="bool", chunks=chunks, compressor=compressor, ) From 3982cf1d2e118b878aad2e2ad0e0c22101ed3e2f Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:47:10 +0100 Subject: [PATCH 579/658] put spiked on shared memory to reduce memory load --- pixels/pixels_utils.py | 101 +++++++++++++++++++++++++++++------------ 1 file changed, 73 insertions(+), 28 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 1efad34..c2f9c8e 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -5,6 +5,7 @@ from __future__ import annotations import multiprocessing as mp +from multiprocessing import shared_memory from concurrent.futures import ProcessPoolExecutor, as_completed import json from pathlib import Path @@ -823,22 +824,35 @@ def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, return None -def _worker_write_repeat(i, zarr_path, sigma, sample_rate): - # Child process re-opens the store to avoid pickling big arrays - store = zarr.DirectoryStore(str(zarr_path)) - root = zarr.open_group(store=store, mode="a") +def _worker_write_repeat( + i, zarr_path, sigma, sample_rate, shm_name, shape, dtype_str, +): + # attach to shared memory + shm = shared_memory.SharedMemory(name=shm_name) + try: + spiked = np.ndarray( + shape, + dtype=np.dtype(dtype_str), + buffer=shm.buf, + ) + # avoid accidental in-place mutation + spiked.setflags(write=False) + # get permuted data + c_spiked, c_fr = _permute_spikes_n_convolve_fr(spiked, sigma, sample_rate) - # read spiked data - spiked = _read_df_from_zarr(root, "spiked") + # child process re-opens the store to avoid pickling big arrays + store = zarr.DirectoryStore(zarr_path) + root = zarr.open_group(store=store, mode="a") - # get permuted data - c_spiked, c_fr = _permute_spikes_n_convolve_fr(spiked[:], sigma, sample_rate) + # Write the i-th slice along last axis + root["chance_spiked"][..., i] = c_spiked + root["chance_fr"][..., i] = c_fr + del c_spiked, c_fr + gc.collect() - # Write the i-th slice along last axis - root["chance_spiked"][..., i] = c_spiked - root["chance_fr"][..., i] = c_fr - - logging.info(f"\nRepeat 
{i} finished.") + logging.info(f"\nRepeat {i} finished.") + finally: + shm.close() return None @@ -867,6 +881,12 @@ def save_spike_chance_zarr( zarr_path = Path(zarr_path) zarr_path.parent.mkdir(parents=True, exist_ok=True) + # Convert to array ONCE for shared memory + if isinstance(spiked, pd.DataFrame): + spike_arr = np.ascontiguousarray(spiked.to_numpy()) + else: + spike_arr = np.ascontiguousarray(np.asarray(spiked)) + base_shape = spiked.shape d_shape = base_shape + (repeats,) @@ -947,28 +967,53 @@ def save_spike_chance_zarr( compressor=compressor, ) - logging.info(f"\n> Starting process pool.") - # Parallel fill: each worker writes a distinct final-axis slice - with ProcessPoolExecutor(max_workers=n_workers) as ex: - futures = [ - ex.submit( - _worker_write_repeat, - i, - zarr_path, - sigma, - sample_rate, - ) for i in range(repeats) - ] - for f in as_completed(futures): - f.result() # raise on error + del positions + gc.collect() + + # Create shared memory once + shm = shared_memory.SharedMemory( + create=True, + size=spike_arr.nbytes, + ) + try: + shm_arr = np.ndarray( + base_shape, + dtype=spike_arr.dtype, + buffer=shm.buf, + ) + shm_arr[...] = spike_arr + + logging.info(f"\n> Starting process pool.") + # Pass only the metadata needed to reconstruct the view + dtype_str = spike_arr.dtype.str # portable dtype spec + + # Parallel fill: each worker writes a distinct final-axis slice + with ProcessPoolExecutor(max_workers=n_workers) as ex: + futures = [ + ex.submit( + _worker_write_repeat, + i, + str(zarr_path), + sigma, + sample_rate, + shm.name, + base_shape, + dtype_str, + ) for i in range(repeats) + ] + for f in as_completed(futures): + f.result() # raise on error + finally: + shm.close() + shm.unlink() - # Done. The decorator will open and return the Zarr content. 
return None def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, fr_df_path, sigma, sample_rate, repeats=100, spiked=None, spiked_shape=None, concat_spiked_path=None): + assert 0 if fr_df_path.exists(): # save spike chance data if does not exists _save_spike_chance( From 252a35696c721f5ce0bca676559b5ba7b3d7e2ec Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:47:38 +0100 Subject: [PATCH 580/658] delete redundant vars --- pixels/pixels_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c2f9c8e..000b662 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -10,6 +10,7 @@ import json from pathlib import Path import zarr +import gc import xarray as xr from numcodecs import Blosc, VLenUTF8 @@ -922,6 +923,8 @@ def save_spike_chance_zarr( chunks=chunks[:-1], compressor=compressor, ) + del spiked + gc.collect() # Outputs if "chance_spiked" in root\ From 22856529f087405c7e32a2cdf8c8567bb252a0db Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:48:28 +0100 Subject: [PATCH 581/658] define chunks upfront --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 000b662..5599f6d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -891,7 +891,7 @@ def save_spike_chance_zarr( base_shape = spiked.shape d_shape = base_shape + (repeats,) - chunks = tuple(min(s, 1024) for s in base_shape) + (1,) + chunks = tuple(min(s, BIG_CHUNKS) for s in base_shape) + (1,) store = zarr.DirectoryStore(str(zarr_path)) root = zarr.group(store=store, overwrite=not zarr_path.exists()) From 9a77ffb0e3b636d34503638fe463a1bc64ee6aae Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:48:52 +0100 Subject: [PATCH 582/658] fix indentation --- pixels/pixels_utils.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 5599f6d..bf9bba4 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -909,20 +909,21 @@ def save_spike_chance_zarr( # Base source so workers can read it without pickling if "spiked" in root: del root["spiked"] - if isinstance(spiked, pd.DataFrame): - _write_df_as_zarr( - root, - spiked, - group_name="spiked", - compressor=compressor, - ) - else: - root.create_dataset( - "spiked", - data=spiked, - chunks=chunks[:-1], - compressor=compressor, - ) + + if isinstance(spiked, pd.DataFrame): + _write_df_as_zarr( + root, + spiked, + group_name="spiked", + compressor=compressor, + ) + else: + root.create_dataset( + "spiked", + data=spiked, + chunks=chunks[:-1], + compressor=compressor, + ) del spiked gc.collect() From 95a509505bc7c9d966581ef0bccf32888884bc96 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:49:16 +0100 Subject: [PATCH 583/658] make sure group_name is generic string --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index bf9bba4..61e3c95 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1741,7 +1741,7 @@ def notch_freq(rec, freq, bw=4.0): def _write_df_as_zarr( root, # zarr.hierarchy.Group df: pd.DataFrame, - group_name: str = "positions", + group_name: str, *, compressor=None, ): From f87d99310d12fd44bde0c17d4e7b06ff3113fff0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:49:36 +0100 Subject: [PATCH 
584/658] add row and column prefix --- pixels/pixels_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 61e3c95..f885707 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1745,6 +1745,9 @@ def _write_df_as_zarr( *, compressor=None, ): + row_prefix = "row" + col_prefix = "col" + # Remove any existing node (array or group) with this name if group_name in root: del root[group_name] From 05b803921ef75383dbbc6907ed0127c039081ac9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:50:24 +0100 Subject: [PATCH 585/658] incorporate row and index multiindex --- pixels/pixels_utils.py | 125 ++++++++++++++++++++++++----------------- 1 file changed, 72 insertions(+), 53 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index f885707..58a3956 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1752,55 +1752,62 @@ def _write_df_as_zarr( if group_name in root: del root[group_name] - ds = xr.Dataset.from_dataframe(df) - - # Find row dimension (the one matching len(df)) - row_dims = [d for d, n in ds.sizes.items() if n == len(df)] - row_dim = row_dims[0] if row_dims else "index" - - # Mark how to rebuild (Multi)Index on read - #ds.attrs["__via"] = "pandas_xarray_df" - #if isinstance(df.index, pd.MultiIndex): - # ds = ds.reset_index(row_dim) - # ds.attrs["__pd_mi_dim__"] = row_dim - # ds.attrs["__pd_mi_levels__"] = [ - # n if n is not None else f"level_{i}" for i, n in enumerate(df.index.names) - # ] - #else: - # ds.attrs["__pd_index_dim__"] = row_dim + # Ensure all index/column level names are defined + if isinstance(df.index, pd.MultiIndex): + row_names = _default_names(list(df.index.names), row_prefix) + else: + row_names = [df.index.name or f"{row_prefix}0"] + if isinstance(df.columns, pd.MultiIndex): - # Encode as a 3D DataArray: (row_dim, *df.columns.names) - s = df.stack(df.columns.names) # Series with index (row_dim, level0, level1, ...) 
- da = xr.DataArray.from_series(s) # dims are names from the Series index - da = da.rename("values") - ds = da.to_dataset() - # Mark for round-trip - ds.attrs["__via"] = "pd_df_mi_cols" - ds.attrs["__row_dim__"] = row_dim - ds.attrs["__col_levels__"] = list(df.columns.names) + col_names = _default_names(list(df.columns.names), col_prefix) else: - # Single-level columns: still avoid per-column variables by using a 2D DataArray (row_dim, col_name) - col_dim = df.columns.name or "columns" - s = df.stack(col_dim) # Series with index (row_dim, col_dim) - da = xr.DataArray.from_series(s).rename("values") - ds = da.to_dataset() - ds.attrs["__via"] = "pd_df_cols" - ds.attrs["__row_dim__"] = row_dim - ds.attrs["__col_levels__"] = [col_dim] - - # chunking - chunks = 1024 - ds = ds.chunk({row_dim: chunks}) - - # Encoding: compressor and variable-length UTF-8 for object columns - #encoding = {v: {"compressor": compressor} for v in ds.data_vars} - #for v, da in ds.data_vars.items(): - # if da.dtype == object and VLenUTF8 is not None: - # encoding[v]["object_codec"] = VLenUTF8() - encoding = {"values": {"compressor": compressor}} + col_names = [df.columns.name or f"{col_prefix}0"] + + # Stack ALL column levels to move them into the row index; result index levels = row_names + col_names + series = df.stack(col_names, future_stack=True) # Series with MultiIndex index + + # Build DataArray (dims are level names of the Series index, in order) + da = xr.DataArray.from_series(series).rename("values") + ds = da.to_dataset() + + # Mark attrs for round-trip (which dims belong to rows vs columns) + ds.attrs["__via"] = "pd_df_any_mi" + ds.attrs["__row_dims__"] = row_names + ds.attrs["__col_dims__"] = col_names + + # check size to determine chunking + chunking = {} + for name, size in ds.sizes.items(): + if size > BIG_CHUNKS: + chunking[name] = BIG_CHUNKS + else: + chunking[name] = SMALL_CHUNKS + + ds = ds.chunk(chunking) + + # compressor & object codec + encoding = { + "values": { + "compressor": compressor, + "chunks": tuple(chunking.values()), + } + } if ds["values"].dtype == object and VLenUTF8 is not None: encoding["values"]["object_codec"] = VLenUTF8() + # Ensure coords are writable (handle object/string coords) + # If VLenUTF8 is available, set encoding for object coords; otherwise cast + # to str + for cname, coord in ds.coords.items(): + if coord.dtype == object: + if VLenUTF8 is not None: + encoding[cname] = { + "object_codec": VLenUTF8(), + "compressor": compressor, + } + else: + ds = ds.assign_coords({cname: coord.astype(str)}) + # Write into a subgroup under the same store ds.to_zarr( store=root.store, @@ -1822,19 +1829,31 @@ def _read_df_from_zarr(root, group_name: str) -> pd.DataFrame: chunks="auto", ) da = ds["values"] - row_dim = ds.attrs.get("__row_dim__", da.dims[0]) - col_levels = ds.attrs.get("__col_levels__", list(da.dims[1:])) + row_dim = list(ds.attrs.get("__row_dims__") or []) + col_dim = list(ds.attrs.get("__col_dims__") or []) - # Series with MultiIndex index (row_dim, *col_levels) - s = da.to_series() + # Series with MultiIndex index (row_dim, *col_dim) + series = da.to_series() # If there are column dims, unstack them back to columns - if col_levels: - df = s.unstack(col_levels) + if col_dim: + df = series.unstack(col_dim) else: # No column dims -> a single column DataFrame - df = s.to_frame(name="values") - - df.index.name = row_dim + df = series.to_frame(name="values") + + col_name = [df.columns.name] + if not (row_dim == df.index.names): + df.index.set_names(row_dim, inplace=True) + 
if not (col_dim == col_name): + if isinstance(df.columns, pd.MultiIndex): + df.columns.set_names(col_dim, inplace=True) + else: + df.columns.name = col_dim[0] return df + + +def _default_names(names: list[str | None], prefix: str) -> list[str]: + # Replace None level names with defaults: f"{prefix}{i}" + return [n if n is not None else f"{prefix}{i}" for i, n in enumerate(names)] From 93a8398c1600ec2730e18720d4dbc933079be4f8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:50:57 +0100 Subject: [PATCH 586/658] reduce size --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index b4d9ba4..61d1efd 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -200,7 +200,7 @@ def _get_vr_spikes(self, units, label, event, sigma, end_event): # initiate binary spike times array for current trial # NOTE: dtype must be float otherwise would get all 0 when passing # gaussian kernel - times = np.zeros((scan_durations[i], len(units)), dtype=float) + times = np.zeros((scan_durations[i], len(units)), dtype=np.float32) # use pixels time as spike index idx = np.arange(scan_starts[i], scan_ends[i]) From 9aac4c60af0182d152ea44b414a8822c0560ace8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:51:15 +0100 Subject: [PATCH 587/658] use predefined number of repeats --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 61d1efd..af5cc26 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1212,7 +1212,7 @@ def get_spike_chance(self, units, label, event, sigma, end_event, spiked=stacked_spikes, sigma=sigma, sample_rate=self.BEHAVIOUR_SAMPLE_RATE, - repeats=2, #REPEATS, + repeats=REPEATS, positions=positions, meta=dict( label=str(label), From 7c8bc6dae9d8cd1335e5069811d6ba0c0499f2fd Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:51:33 +0100 Subject: [PATCH 588/658] add chunks --- pixels/behaviours/base.py | 1 - pixels/constants.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 2934732..d949990 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2619,7 +2619,6 @@ def get_spike_chance(self, label, event, sigma, end_event): end_event=end_event, ) - assert 0 return output diff --git a/pixels/constants.py b/pixels/constants.py index 1d5cdf4..d386c9e 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -24,3 +24,7 @@ V1_SPIKE_LATENCY = 500 # ms #60 V1_LFP_LATENCY = 40 # ms LGN_SPIKE_LATENCY = 30 # ms + +# chunking for zarr +SMALL_CHUNKS = 64 +BIG_CHUNKS = 1024 From 65b89fa4c326102af5c91c31ee631c69c70e92c9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:52:21 +0100 Subject: [PATCH 589/658] call it group_name --- pixels/decorators.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index ea00177..3e98859 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -46,9 +46,7 @@ def _df_to_zarr_via_xarray( *, path: Path | None = None, store: "zarr.storage.Store" | None = None, - group: str | None = None, - dim_name: str | None = None, - chunks: int | dict | None = None, + group_name: str | None = None, compressor=None, mode: str = "w", ) -> None: @@ -120,7 +118,7 @@ def _df_from_zarr_via_xarray( *, path: Path | None = None, store: "zarr.storage.Store" | None = None, - group: str | None = None, + 
group_name: str | None = None, ) -> pd.DataFrame: """ Read a DataFrame written by _df_to_zarr_via_xarray and reconstruct @@ -224,9 +222,7 @@ def write_into(prefix: str, value: Any): _df_to_zarr_via_xarray( value, store=store, - group=prefix or "", - dim_name=value.index.name or "index", - chunks=chunks, + group_name=prefix or "", compressor=compressor, mode="w", ) @@ -284,8 +280,8 @@ def _read_zarr_generic(root_path: Path) -> Any: root = zarr.open_group(store=store, mode="r") # If top-level was written via xarray as a DataFrame - if root.attrs.get("__via") == "pandas_xarray_df" and xr is not None: - return _df_from_zarr_via_xarray(store=store, group="") + if root.attrs.get("__via") == "pd_df_any_mi" and xr is not None: + return _df_from_zarr_via_xarray(store=store, group_name="") # If top-level is a single array written at root if "array" in root and isinstance(root["array"], zarr.Array): @@ -294,8 +290,11 @@ def _read_zarr_generic(root_path: Path) -> Any: def read_from_group(prefix: str) -> Any: g = zarr.open_group(store=store, path=prefix, mode="r") # DataFrame group? - if g.attrs.get("__via") == "pandas_xarray_df" and xr is not None: - return _df_from_zarr_via_xarray(store=store, group=prefix or "") + if g.attrs.get("__via") == "pd_df_any_mi" and xr is not None: + return _df_from_zarr_via_xarray( + store=store, + group_name=prefix or "", + ) out: dict[str, Any] = {} for name, node in g.items(): From cb4e766c4c5713a1cb405b77d0dc68fa7df144d9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:52:34 +0100 Subject: [PATCH 590/658] incorporate row and index multiindex --- pixels/decorators.py | 134 +++++++++++++++++++++++++++++-------------- 1 file changed, 92 insertions(+), 42 deletions(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 3e98859..71a6782 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -62,54 +62,86 @@ def _df_to_zarr_via_xarray( "xarray/zarr not installed. 
pip install xarray zarr numcodecs" ) - ds = xr.Dataset.from_dataframe(df) + row_prefix = "row" + col_prefix = "col" - # Find row dimension (the one matching len(df)) - row_dims = [d for d, n in ds.sizes.items() if n == len(df)] - row_dim = row_dims[0] if row_dims else "index" - - # Rename the row dimension if requested - if dim_name and row_dim != dim_name: - ds = ds.rename({row_dim: dim_name}) - row_dim = dim_name - - # Record how to reconstruct on read - ds.attrs["__via"] = "pandas_xarray_df" + # Ensure all index/column level names are defined if isinstance(df.index, pd.MultiIndex): - ds = ds.reset_index(row_dim) - ds.attrs["__pd_mi_dim__"] = row_dim - ds.attrs["__pd_mi_levels__"] = [ - n if n is not None else f"level_{i}"\ - for i, n in enumerate(df.index.names) - ] + row_names = _default_names(list(df.index.names), row_prefix) + else: + row_names = [df.index.name or f"{row_prefix}0"] + + if isinstance(df.columns, pd.MultiIndex): + col_names = _default_names(list(df.columns.names), col_prefix) else: - ds.attrs["__pd_index_dim__"] = row_dim + col_names = [df.columns.name or f"{col_prefix}0"] + + # Stack ALL column levels to move them into the row index; result index levels = row_names + col_names + series = df.stack(col_names, future_stack=True) # Series with MultiIndex index + + # Build DataArray (dims are level names of the Series index, in order) + da = xr.DataArray.from_series(series).rename("values") + ds = da.to_dataset() - # Chunking - if chunks is not None: - if isinstance(chunks, int): - ds = ds.chunk({row_dim: chunks}) - elif isinstance(chunks, dict): - ds = ds.chunk(chunks) + # Mark attrs for round-trip (which dims belong to rows vs columns) + ds.attrs["__via"] = "pd_df_any_mi" + ds.attrs["__row_dims__"] = row_names + ds.attrs["__col_dims__"] = col_names + + # check size to determine chunking + chunking = {} + for name, size in ds.sizes.items(): + if size > BIG_CHUNKS: + chunking[name] = BIG_CHUNKS + else: + chunking[name] = SMALL_CHUNKS + + ds = ds.chunk(chunking) - # Compression and string handling if compressor is None: compressor = _make_default_compressor() - encoding = {var: {"compressor": compressor} for var in ds.data_vars} - for var, da in ds.data_vars.items(): - if da.dtype == object and VLenUTF8 is not None: - encoding[var]["object_codec"] = VLenUTF8() + # compressor & object codec + encoding = { + "values": { + "compressor": compressor, + "chunks": tuple(chunking.values()), + } + } + if ds["values"].dtype == object and VLenUTF8 is not None: + encoding["values"]["object_codec"] = VLenUTF8() + + # Ensure coords are writable (handle object/string coords) + # If VLenUTF8 is available, set encoding for object coords; otherwise cast + # to str + for cname, coord in ds.coords.items(): + if coord.dtype == object: + if VLenUTF8 is not None: + encoding[cname] = { + "object_codec": VLenUTF8(), + "compressor": compressor, + } + else: + ds = ds.assign_coords({cname: coord.astype(str)}) # Write if path is not None: - ds.to_zarr(str(path), mode=mode, encoding=encoding) + ds.to_zarr( + str(path), + mode=mode, + encoding=encoding, + ) try: zarr.consolidate_metadata(str(path)) except Exception: pass else: assert store is not None - ds.to_zarr(store=store, group=group or "", mode=mode, encoding=encoding) + ds.to_zarr( + store=store, + group=group_name or "", + mode=mode, + encoding=encoding, + ) # consolidate requires a path; skipping here since we're inside a shared # store @@ -131,22 +163,42 @@ def _df_from_zarr_via_xarray( ) if path is not None: - ds = xr.open_zarr(str(path), 
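
As context for the writer/reader pair being reworked here: the writer stacks every column level into the row index so that a DataFrame with any MultiIndex becomes a single labelled xarray variable, and the matching reader reverses it with unstack. A minimal self-contained sketch of that round trip with toy data and made-up level names (plain pandas/xarray, no zarr involved, and assuming pandas >= 2.1 for future_stack):

import numpy as np
import pandas as pd
import xarray as xr

df = pd.DataFrame(
    np.arange(12, dtype=np.float32).reshape(3, 4),
    index=pd.Index([0, 1, 2], name="time"),
    columns=pd.MultiIndex.from_product(
        [[170, 230], ["t0", "t1"]], names=("start", "trial"),
    ),
)

# write side: move all column levels into the row index, then to a DataArray
series = df.stack(["start", "trial"], future_stack=True)
da = xr.DataArray.from_series(series).rename("values")

# read side: back to a Series, then unstack the column dims into columns
restored = da.to_series().unstack(["start", "trial"])
assert np.allclose(restored.values, df.values)   # lossless round trip
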
consolidated=True, chunks="auto") + ds = xr.open_zarr( + str(path), + consolidated=True, + chunks="auto", + ) else: ds = xr.open_zarr( store=store, - group=group or "", + group=group_name or "", consolidated=False, chunks="auto", ) - # Reconstruct MI if recorded - mi_dim = ds.attrs.get("__pd_mi_dim__") - mi_levels = ds.attrs.get("__pd_mi_levels__") - if mi_dim and mi_levels: - ds = ds.set_index({mi_dim: mi_levels}) + da = ds["values"] + row_dim = list(ds.attrs.get("__row_dims__") or []) + col_dim = list(ds.attrs.get("__col_dims__") or []) + + # Series with MultiIndex index (row_dim, *col_dim) + series = da.to_series() + + # If there are column dims, unstack them back to columns + if col_dim: + df = series.unstack(col_dim) + else: + # No column dims -> a single column DataFrame + df = series.to_frame(name="values") + + col_name = [df.columns.name] + if not (row_dim == df.index.names): + df.index.set_names(row_dim, inplace=True) + if not (col_dim == col_name): + if isinstance(df.columns, pd.MultiIndex): + df.columns.set_names(col_dim, inplace=True) + else: + df.columns.name = col_dim[0] - df = ds.to_dataframe() return df @@ -479,8 +531,6 @@ def wrapper(*args, **kwargs): _df_to_zarr_via_xarray( result, path=zarr_path, - dim_name=per_call_dim_name or result.index.name or "index", - chunks=per_call_chunks, compressor=compressor, mode="w", ) From 8331c26b3c9aff4121b8f7607a2093d537835f28 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:53:26 +0100 Subject: [PATCH 591/658] move implementation to decorator and use it from there --- pixels/pixels_utils.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 58a3956..70b2922 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -32,6 +32,7 @@ from pixels.error import PixelsError from pixels.configs import * from pixels.constants import * +from pixels.decorators import _df_to_zarr_via_xarray, _df_from_zarr_via_xarray from common_utils import math_utils from common_utils.file_utils import init_memmap, read_hdf5 @@ -911,12 +912,19 @@ def save_spike_chance_zarr( del root["spiked"] if isinstance(spiked, pd.DataFrame): - _write_df_as_zarr( - root, - spiked, + _df_to_zarr_via_xarray( + df=spiked, + store=store, group_name="spiked", compressor=compressor, + mode="w", ) + #_write_df_as_zarr( + # root, + # spiked, + # group_name="spiked", + # compressor=compressor, + #) else: root.create_dataset( "spiked", @@ -957,12 +965,19 @@ def save_spike_chance_zarr( del root["positions"] if isinstance(positions, pd.DataFrame): - _write_df_as_zarr( - root, - positions, + _df_to_zarr_via_xarray( + df=positions, + store=store, group_name="positions", compressor=compressor, + mode="w", ) + #_write_df_as_zarr( + # root, + # positions, + # group_name="positions", + # compressor=compressor, + #) else: root.create_dataset( "positions", From 5462bfe7190d7790feccb2dcf0e35c1a87d2417c Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 17:54:12 +0100 Subject: [PATCH 592/658] start implementing multiprocessing in firing rate convolution --- pixels/signal_utils.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/pixels/signal_utils.py b/pixels/signal_utils.py index abcd3a1..036621f 100644 --- a/pixels/signal_utils.py +++ b/pixels/signal_utils.py @@ -301,6 +301,23 @@ def median_subtraction(data, axis=0): return data - np.median(data, axis=axis, keepdims=True) +def _convolve_worker(shm_kernal, shm_times, 
sample_rate): + # TODO sep 22 2025: + # CONTINUE HERE! + # attach to shared memory + shm = shared_memory.SharedMemory(name=shm_kernal) + + convolved = convolve1d( + input=times.values, + weights=n_kernel, + output=np.float32, + mode="nearest", + axis=0, + ) * sample_rate # rescale it to second + + return None + + def convolve_spike_trains(times, sigma=100, size=10, sample_rate=1000): """ Convolve spike times data with 1D gaussian kernel to get spike rate. @@ -327,12 +344,14 @@ def convolve_spike_trains(times, sigma=100, size=10, sample_rate=1000): # normalise kernel to ensure that the total area under the Gaussian is 1 n_kernel = kernel / np.sum(kernel) + # TODO sept 19 2025: + # implement multiprocessing? if isinstance(times, pd.DataFrame): # convolve with gaussian convolved = convolve1d( input=times.values, weights=n_kernel, - output=float, + output=np.float32, mode='nearest', axis=0, ) * sample_rate # rescale it to second @@ -348,7 +367,7 @@ def convolve_spike_trains(times, sigma=100, size=10, sample_rate=1000): output = convolve1d( input=times, weights=n_kernel, - output=float, + output=np.float32, mode='nearest', axis=0, ) * sample_rate # rescale it to second From bc4e34c5086a50b4a6fc4795fa6cf3d4145124be Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 24 Sep 2025 19:00:35 +0100 Subject: [PATCH 593/658] make sure the name is correct --- pixels/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/experiment.py b/pixels/experiment.py index bfd6453..abb3663 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -614,4 +614,4 @@ def get_spike_chance(self, *args, **kwargs): name = session.name chance[name] = session.get_spike_chance(*args, **kwargs) - return df + return chance From d83aa0d009a0c7d1136e501e1756593ba1871588 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 29 Sep 2025 13:23:19 +0100 Subject: [PATCH 594/658] use PixelsError to track errors --- pixels/stream.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index af5cc26..e25a6d8 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -15,6 +15,7 @@ from pixels.configs import * from pixels.constants import * from pixels.decorators import cacheable +from pixels.error import PixelsError from common_utils import file_utils @@ -73,6 +74,7 @@ def _map_trials(self, label, event, end_event=None): # only take starts from selected trials selected_starts = trials[np.where(np.isin(trials, starts))[0]] + raise PixelsError("\n> Why would we have more ends than starts?") start_t = timestamps[selected_starts] # only take ends from selected trials selected_ends = trials[np.where(np.isin(trials, ends))[0]] From e8f5deb6f8b2931fd7e39ad4d2188d78b59e4aee Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 29 Sep 2025 13:24:36 +0100 Subject: [PATCH 595/658] make sure to only include trials with both start and end events --- pixels/stream.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index e25a6d8..8702d88 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -72,15 +72,24 @@ def _map_trials(self, label, event, end_event=None): # map starts by end event ends = np.where(np.bitwise_and(events, end_event))[0] - # only take starts from selected trials - selected_starts = trials[np.where(np.isin(trials, starts))[0]] + # only take starts and ends from selected trials + selected_starts = trials[np.isin(trials, starts)] + selected_ends = trials[np.isin(trials, ends)] + + # make sure 
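
The float32 change above lands in convolve_spike_trains, which smooths binary spike indicators into a firing-rate estimate by convolving with a unit-area Gaussian kernel and rescaling by the sample rate. A rough standalone sketch of that idea; scipy is assumed, and the kernel length of 10 * sigma and the toy spike array are illustrative rather than the package's exact settings:

import numpy as np
from scipy.ndimage import convolve1d
from scipy.signal.windows import gaussian

sample_rate = 1000                      # Hz
sigma = 100                             # kernel width in samples at 1 kHz
kernel = gaussian(10 * sigma, sigma)
kernel /= kernel.sum()                  # unit area, so the output is a rate estimate

spikes = np.zeros((5000, 3), dtype=np.float32)      # time x units, binary
spikes[[100, 1500, 1505, 4000], [0, 1, 1, 2]] = 1.0

rates = convolve1d(
    input=spikes,
    weights=kernel,
    axis=0,
    mode="nearest",
    output=np.float32,
) * sample_rate                         # rescale to spikes per second
print(rates.shape, rates.max())
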
trials have both starts and ends + start_ids = synched_vr.iloc[selected_starts].trial_count.unique() + end_ids = synched_vr.iloc[selected_ends].trial_count.unique() + if len(start_ids) > len(end_ids): + selected_starts = selected_starts[np.isin(start_ids, end_ids)] + elif len(start_ids) < len(end_ids): + selected_ends = selected_ends[np.isin(end_ids, start_ids)] raise PixelsError("\n> Why would we have more ends than starts?") + + # get timestamps start_t = timestamps[selected_starts] - # only take ends from selected trials - selected_ends = trials[np.where(np.isin(trials, ends))[0]] end_t = timestamps[selected_ends] - # use original trial id as trial index + # use original trial ids as trial index trial_ids = pd.Index( synched_vr.iloc[selected_starts].trial_count.unique() ) From 6b197f882979063c81a5b6f019be3d676960e293 Mon Sep 17 00:00:00 2001 From: amz_office Date: Mon, 29 Sep 2025 13:28:08 +0100 Subject: [PATCH 596/658] make sure to only get the starting positions of the included trials --- pixels/stream.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 8702d88..e73363c 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -124,31 +124,36 @@ def _get_vr_positions(self, label, event, end_event): all_pos_val = all_pos.to_numpy() all_pos_idx = all_pos.index.to_numpy() + # find start and end position index + trial_start_t = np.searchsorted(all_pos_idx, start_t, side="left") + trial_end_t = np.searchsorted(all_pos_idx, end_t, side="right") + trials_positions = [ + pd.Series(all_pos_val[s:e]) + for s, e in zip(trial_start_t, trial_end_t) + ] + positions = pd.concat(trials_positions, axis=1) + # map actual starting locations if not "trial_start" in event.name: all_start_idx = np.flatnonzero(events & event.trial_start) - start_idx = trials[np.isin(trials, all_start_idx)] + trial_start_idx = trials[np.isin(trials, all_start_idx)] + # make sure to only get the included trials' starting positions, + # i.e., the one with start and end event, not all trials in the + # label by aligning trial ids + all_ids = synched_vr.trial_count.iloc[trial_start_idx].values + start_idx = trial_start_idx[np.isin(all_ids, trial_ids)] else: start_idx = selected_starts.copy() # get start positions start_pos = all_pos_val[start_idx].astype(int) - # create multiindex with starts cols_with_starts = pd.MultiIndex.from_arrays( [start_pos, trial_ids], names=("start", "trial"), ) - # find start and end position index - trial_start_t = np.searchsorted(all_pos_idx, start_t, side="left") - trial_end_t = np.searchsorted(all_pos_idx, end_t, side="right") - - trials_positions = [ - pd.Series(all_pos_val[s:e]) - for s, e in zip(trial_start_t, trial_end_t) - ] - positions = pd.concat(trials_positions, axis=1) + # add level with start positions positions.columns = cols_with_starts positions = positions.sort_index(axis=1, ascending=[False, True]) positions.index.name = "time" From 3c9785480e46df8a92ca1a05994bb14a5d70bbf5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 1 Oct 2025 18:35:15 +0100 Subject: [PATCH 597/658] add alpha --- pixels/constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/constants.py b/pixels/constants.py index d386c9e..7ba414c 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -28,3 +28,5 @@ # chunking for zarr SMALL_CHUNKS = 64 BIG_CHUNKS = 1024 + +ALPHA = 0.05 From 4f419a2dd23021169037535617fd253f4b93de80 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 1 Oct 2025 18:35:47 +0100 
Subject: [PATCH 598/658] make sure to use the full base name --- pixels/decorators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 71a6782..4f38dd2 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -446,7 +446,7 @@ def wrapper(*args, **kwargs): # HDF5 backend if backend == "hdf5": - cache_path = base.with_suffix(".h5") + cache_path = base.with_name(base.name + ".h5") if cache_path.exists() and inst._use_cache != "overwrite": try: df = ioutils.read_hdf5(cache_path) @@ -497,7 +497,7 @@ def wrapper(*args, **kwargs): raise ImportError( "cache_format='zarr' requires zarr. pip install zarr numcodecs xarray" ) - zarr_path = base.with_suffix(".zarr") + zarr_path = base.with_name(base.name + ".h5") can_read = zarr_path.exists() and inst._use_cache != "overwrite" if can_read: From 561510cdea37bd341bbd4c9d34a6bbfef07c07da Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 1 Oct 2025 18:36:24 +0100 Subject: [PATCH 599/658] use GLM to find landmark responsive units --- pixels/behaviours/base.py | 30 ++++ pixels/pixels_utils.py | 335 ++++++++++++++++++++++++++++++++++++++ pixels/stream.py | 21 +++ 3 files changed, 386 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index d949990..6e8188d 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2719,3 +2719,33 @@ def get_spatial_psd( ) return output + + + def get_landmark_responsives( + self, label, event, end_event=None, sigma=None, units=None, + pos_bin=None, + ): + output = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting landmark responsive units from {units} in " + f"<{label.name}> trials." 
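
The switch from with_suffix to with_name above matters whenever the cache base name itself contains a dot, because with_suffix replaces everything after the last dot. A small illustration with made-up paths:

from pathlib import Path

base = Path("cache/aligned_trials_sigma_0.5")
print(base.with_suffix(".h5"))              # cache/aligned_trials_sigma_0.h5, the "5" is lost
print(base.with_name(base.name + ".h5"))    # cache/aligned_trials_sigma_0.5.h5
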
+ ) + output[stream_id] = stream.get_landmark_responsives( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + pos_bin=pos_bin, + ) + + return output diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 70b2922..7cd38f8 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -18,6 +18,11 @@ import numpy as np import pandas as pd +from scipy import stats +import statsmodels.formula.api as smf +from statsmodels.stats.multitest import multipletests +from patsy import build_design_matrices + import spikeinterface as si import spikeinterface.extractors as se import spikeinterface.sorters as ss @@ -1872,3 +1877,333 @@ def _read_df_from_zarr(root, group_name: str) -> pd.DataFrame: def _default_names(names: list[str | None], prefix: str) -> list[str]: # Replace None level names with defaults: f"{prefix}{i}" return [n if n is not None else f"{prefix}{i}" for i, n in enumerate(names)] + + +# >>> landmark responsive helpers >>> +def to_df(mean, std, zone): + out = pd.DataFrame({"mean": mean, "std": std}).reset_index() + + # Keep only start and trial; unit is constant + out = out.rename( + columns={"level_0": "start", "level_1": "unit", "level_2": "trial"} + ) + out["zone"] = zone + + # map data type + out["start"] = out["start"].astype(str) + out["unit"] = out["unit"].astype(str) + out["trial"] = out["trial"].astype(str) + + return out + + +# Build linear contrasts for any model that uses patsy coding +def compute_contrast(fit, newdf_a, newdf_b=None): + """ + Returns estimate and SE for: + L'beta where L = X(newdf_a) - X(newdf_b) if newdf_b is provided, + else L = X(newdf_a) + fit: statsmodels results with fixed effects (OLS or MixedLM) + newdf_a/newdf_b: small DataFrames with columns used in the formula (start, + zone) + """ + # For MixedLM, use fixed-effects params/cov + if hasattr(fit, "fe_params"): + fe_params = fit.fe_params + cov = fit.cov_params().loc[fe_params.index, fe_params.index] + cols = fe_params.index + else: + # OLS + fe_params = fit.params + cov = fit.cov_params() + cols = fit.params.index + + di = fit.model.data.design_info + + Xa = build_design_matrices([di], newdf_a)[0] + Xa = np.asarray(Xa) # column order matches the fit + if Xa.shape[1] != len(cols): + # rare: ensure columns align if needed + raise ValueError( + "Design column mismatch; " + "ensure newdf has the same factors and levels as the fit." + ) + + if newdf_b is not None: + Xb = build_design_matrices([di], newdf_b)[0] + Xb = np.asarray(Xb) + L = (Xa - Xb).ravel() + else: + L = Xa.ravel() + + est = float(L @ fe_params.values) + se = float(np.sqrt(L @ cov.values @ L)) + + return est, se + + +# >>> single unit mixed model +def fit_per_unit_ols(df, formula, unit_id): + """ + Step 1 + Fit mean fr of pre-wall, landmark, and post-wall from each trial, each unit + to GLM with cluster-robust SE. + """ + d = df[df["unit"] == str(unit_id)].copy() + if d.empty: + raise ValueError(f"No data for unit {unit_id}") + # OLS with cluster-robust SE by trial + fit = smf.ols(formula, data=d).fit( + cov_type="cluster", + cov_kwds={"groups": d["trial"]}, + ) + return fit + + +def test_diff_any(fit, starts, use_f=True): + """ + Step 2 + Use Wald test on linear contrasts to test if jointly, all these contrasts + are 0. + i.e., this test if there are any difference among the mean fr comparisons. 
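
For orientation, a toy end-to-end version of the per-unit model this docstring refers to: one row per zone and trial of cell means, OLS with trial-clustered standard errors, then a landmark minus pre_wall contrast assembled from the fitted design info. Random numbers stand in for the firing-rate means that the real pipeline supplies:

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from patsy import build_design_matrices

rng = np.random.default_rng(0)
d = pd.DataFrame({
    "mean": rng.gamma(2.0, 2.0, size=60),
    "zone": np.tile(["pre_wall", "landmark", "post_wall"], 20),
    "start": np.repeat(["170", "230"], 30),
    "trial": np.repeat([str(t) for t in range(20)], 3),
})

fit = smf.ols(
    "mean ~ C(start) * C(zone, Treatment(reference='pre_wall'))",
    data=d,
).fit(cov_type="cluster", cov_kwds={"groups": d["trial"]})

# landmark - pre_wall at one start, built from rows of the fitted design matrix
di = fit.model.data.design_info
Xa = np.asarray(build_design_matrices(
    [di], pd.DataFrame({"start": ["170"], "zone": ["landmark"]}))[0])
Xb = np.asarray(build_design_matrices(
    [di], pd.DataFrame({"start": ["170"], "zone": ["pre_wall"]}))[0])
L = (Xa - Xb).ravel()
est = float(L @ fit.params.values)
se = float(np.sqrt(L @ fit.cov_params().values @ L))
print(est, se, est / se)
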
+ if wald p < alpha, then there is a significant difference, we do post-hoc to + see where the difference come from; + if wald p > alpha, the unit does not have different fr between landmark & + pre-wall, and landmark & post-wall, or pre-wall & post-wall. + """ + Ls = [] + for s in starts: + for ref in ["pre_wall", "post_wall"]: + row = _L_row( + fit=fit, + start_label=s, + a_zone="landmark", + b_zone=ref, + ) + Ls.append(row) + + R = np.vstack(Ls) + w = fit.wald_test(R, scalar=True, use_f=use_f) + results = { + "stat": float(w.statistic), + "p": float(w.pvalue), + "df_num": int(getattr(w, "df_num", R.shape[0])), + "df_denom": float(getattr(w, "df_denom", np.nan)), + "k": R.shape[0], + } + + return results["p"] + + +def _L_row(fit, start_label, a_zone, b_zone): + """ + Build linear contrasts of zones for a given starting position. + """ + di = fit.model.data.design_info + Xa = np.asarray( + build_design_matrices( + [di], + pd.DataFrame({"start":[start_label], "zone":[a_zone]}), + )[0] + ).ravel() + Xb = np.asarray( + build_design_matrices( + [di], + pd.DataFrame({"start":[start_label], "zone":[b_zone]}) + )[0] + ).ravel() + + return Xa - Xb + + +def family_comparison(fit, starts, compare_to="pre_wall", use_f=True): + """ + Step 3 + Family level mean comparison, i.e., compare pre or post wall with landmark. + """ + R = np.vstack( + [_L_row(fit, s, "landmark", compare_to) for s in starts] + ) + w = fit.wald_test(R, scalar=True, use_f=use_f) + + results = { + "family": f"LM-{ 'Pre' if compare_to=='pre_wall' else 'Post' }", + "stat": float(w.statistic), + "p": float(w.pvalue), + "df_num": int(getattr(w, "df_num", R.shape[0])), + "df_denom": float(getattr(w, "df_denom", np.nan)), + "n_starts": len(starts), + "R": R + } + return results["p"] + + +def start_contrasts_ols(fit, starts, use_normal=True): + """ + Step 4 + Post-hoc test to see where the difference in contrast come from, i.e., get + the linear contrast for each starting positions. 
+ """ + params = fit.params if not hasattr(fit, "fe_params")\ + else fit.fe_params + cov = fit.cov_params() if not hasattr(fit, "fe_params")\ + else fit.cov_params().loc[params.index, params.index] + + rows = [] + for s in sorted(starts, key=str): + df_pre = pd.DataFrame({"start":[s], "zone":["pre_wall"]}) + df_lm = pd.DataFrame({"start":[s], "zone":["landmark"]}) + df_post = pd.DataFrame({"start":[s], "zone":["post_wall"]}) + + for label, A, B in [("lm-pre", df_lm, df_pre), + ("lm-post", df_lm, df_post)]: + est, se = compute_contrast(fit, A, B) + if not np.isfinite(se) or se <= 0: + stat = p = np.nan + else: + stat = est / se + p = float( + 2 * (stats.norm.sf(abs(stat)) if use_normal + else stats.t.sf(abs(stat), df=fit.df_resid)) + ) + rows.append({ + "start": s, + "contrast": label, + "coef": est, + "se": se, + "stat": stat, + "p": p, + }) + + out = pd.DataFrame(rows) + if not out.empty and out["p"].notna().any(): + # get Holm-adjusted p value to correct for multiple comparison + out["p_holm"] = multipletests( + out["p"], + alpha=ALPHA, + method="holm", + )[1] + + return out + + +def test_start_x_zone_interaction_ols(fit): + # Wald test: all interaction terms = 0 + ix = [i for i, zone in enumerate(fit.params.index) if ":" in zone] + if not ix: + return np.nan + R = np.zeros((len(ix), len(fit.params))) + for r, i in enumerate(ix): + R[r, i] = 1.0 + w = fit.wald_test(R, scalar=True, use_f=False) + stat = w.statistic + + return float(w.statistic), float(w.pvalue) +# <<< single unit mixed model +# <<< landmark responsive helpers <<< + + +def get_landmark_responsives(pos_fr, units, pos_bin): + from vision_in_darkness.constants import landmarks + + # get on & off of landmark and stack them + lms = landmarks[1:-2].reshape((-1, 2)) + + # get pre & post wall + wall_on = landmarks[:-1][::2] + pos_bin + wall_off = landmarks[:-1][1::2] - pos_bin + pre_walls = np.column_stack([wall_on[:-1], wall_off[:-1]]) + post_walls = np.column_stack([wall_on[1:], wall_off[1:]]) + + # group pre wall, landmarks, and post wall together + chunks = np.stack([pre_walls, lms, post_walls], axis=1) + + ons = chunks[..., 0] + offs = chunks[..., 1] + + # get all positions + positions = pos_fr.index.to_numpy() + # build mask for all positions + mask = (positions[:, None, None] >= ons[None, :, :])\ + & (positions[:, None, None] <= offs[None, :, :]) + + lm_responsive_bool = np.zeros( + (len(units), lms.shape[0]) + ).astype(bool) + responsives = pd.DataFrame( + lm_responsive_bool, + index=units, + columns=np.arange(lms.shape[0]), + ) + responsives.index.name = "unit" + responsives.columns.name = "landmark" + + for l in range(lms.shape[0]): + # get all data within each chunk + chunk = pos_fr.loc[mask[:, l, :]].dropna(axis=1) + + # build mask chunk positions + chunk_pos = chunk.index.values + chunk_mask = ( + chunk_pos[:, None] >= ons[l, :] + ) & (chunk_pos[:, None] <= offs[l, :]) + + # get mean & std of walls and landmark + pre_wall = chunk.loc[chunk_mask[:, 0]].dropna(axis=1) + pre_wall_mean = pre_wall.mean(axis=0) + pre_wall_std = pre_wall.std(axis=0) + + landmark = chunk.loc[chunk_mask[:, 1]].dropna(axis=1) + landmark_mean = landmark.mean(axis=0) + landmark_std = landmark.std(axis=0) + + post_wall = chunk.loc[chunk_mask[:, 2]].dropna(axis=1) + post_wall_mean = post_wall.mean(axis=0) + post_wall_std = post_wall.std(axis=0) + + # aggregate + agg = pd.concat([ + to_df(pre_wall_mean, pre_wall_std, "pre_wall"), + to_df(landmark_mean, landmark_std, "landmark"), + to_df(post_wall_mean, post_wall_std, "post_wall"), + ], + 
ignore_index=True, + ) + agg["zone"] = pd.Categorical( + agg["zone"], + categories=["pre_wall", "landmark", "post_wall"], + ordered=True, + ) + + # get all starting positions + starts = agg.start.unique() + + # create model formula + min_start = min(starts) + simple_model = ( + "mean ~ C(zone, Treatment(reference='pre_wall'))" + ) + full_model = ( + f"""mean + ~ C(start, Treatment(reference={min_start})) + * C(zone, Treatment(reference='pre_wall'))""" + ) + + for unit_id in units: + unit_fit = fit_per_unit_ols( + df=agg, + formula=full_model, + #formula=simple_model, + unit_id=unit_id, + ) + #print(unit_fit.summary()) + # step 4: check contrast at each start + unit_contrasts = start_contrasts_ols( + fit=unit_fit, + starts=starts, + ) + if (unit_contrasts.coef > 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id, l] = True + + return responsives diff --git a/pixels/stream.py b/pixels/stream.py index e73363c..53f8c2f 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1278,3 +1278,24 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): psds = xut.save_chance_psd(self.BEHAVIOUR_SAMPLE_RATE, positions, paths) return psds + + + @cacheable + def get_landmark_responsives( + self, units, label, event, sigma, end_event, pos_bin, + ): + pos_fr = self.get_positional_data( + units=units, # NOTE: ALWAYS the first arg + label=label, + event=event, + sigma=sigma, + end_event=end_event, # NOTE: ALWAYS the last arg + )["pos_fr"] + + responsives = xut.get_landmark_responsives( + pos_fr=pos_fr, + units=units[self.stream_id], + pos_bin=pos_bin, + ) + + return responsives From eb3a996f013f99c88d4e3de29df355202bd0b0b6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 16:36:26 +0100 Subject: [PATCH 600/658] only inject reserved kwarg if the method accepts --- pixels/decorators.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 4f38dd2..07d732c 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -368,6 +368,25 @@ def read_from_group(prefix: str) -> Any: return read_from_group("") +def _filter_reserved_kwargs(fn, reserved: dict[str, Any]) -> dict[str, Any]: + """ + Return subset of `reserved` that `fn` will accept (has a parameter by that + name, or has **kwargs). 
+ """ + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return {} + # if method has **kwargs, pass everything + if any( + p.kind == inspect.Parameter.VAR_KEYWORD + for p in sig.parameters.values() + ): + return reserved + + # otherwise pass only the names it declares + return {k: v for k, v in reserved.items() if k in sig.parameters} + # ----------------------- # Decorator with Zarr # ----------------------- @@ -509,8 +528,9 @@ def wrapper(*args, **kwargs): logging.info(f"\n> Failed to read Zarr cache ({e}); recomputing.") # inject reserved kwargs so the method can write directly to - # store - kwargs["_zarr_out"] = zarr_path + # store, if the method accepts + reserved = {"_zarr_out": zarr_path} + kwargs.update(_filter_reserved_kwargs(method, reserved)) # Compute fresh result = method(*args, **kwargs) From bbc7b75e31e7a85a93a2b7ed0db566c3c2177e35 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 16:36:41 +0100 Subject: [PATCH 601/658] zarr suffix is zarr --- pixels/decorators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 07d732c..9c4ef60 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -5,6 +5,7 @@ from pathlib import Path from functools import wraps from typing import Any +import inspect import numpy as np import pandas as pd @@ -516,7 +517,7 @@ def wrapper(*args, **kwargs): raise ImportError( "cache_format='zarr' requires zarr. pip install zarr numcodecs xarray" ) - zarr_path = base.with_name(base.name + ".h5") + zarr_path = base.with_name(base.name + ".zarr") can_read = zarr_path.exists() and inst._use_cache != "overwrite" if can_read: From d5283e27b95623f8629a56f1c95d53b39ffb0692 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 16:38:04 +0100 Subject: [PATCH 602/658] add missing method --- pixels/decorators.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/decorators.py b/pixels/decorators.py index 9c4ef60..4ecc48f 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -42,6 +42,10 @@ def _make_default_compressor() -> Any: # xarray <-> DataFrame helpers # --------------------------- +def _default_names(names: list[str | None], prefix: str) -> list[str]: + # Replace None level names with defaults: f"{prefix}{i}" + return [n if n is not None else f"{prefix}{i}" for i, n in enumerate(names)] + def _df_to_zarr_via_xarray( df: pd.DataFrame, *, From 57c3d35312d04ae984184d4643c371cfeb3f2102 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 16:38:31 +0100 Subject: [PATCH 603/658] explicitly name stat --- pixels/pixels_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 7cd38f8..4f98196 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2066,14 +2066,10 @@ def start_contrasts_ols(fit, starts, use_normal=True): 2 * (stats.norm.sf(abs(stat)) if use_normal else stats.t.sf(abs(stat), df=fit.df_resid)) ) - rows.append({ - "start": s, - "contrast": label, - "coef": est, - "se": se, - "stat": stat, - "p": p, - }) + col_names = ["start", "contrast", "coef", "SE", "stat", "p"] + rows.append( + dict(zip(col_names, [s, label, est, se, stat, p])) + ) out = pd.DataFrame(rows) if not out.empty and out["p"].notna().any(): @@ -2084,6 +2080,10 @@ def start_contrasts_ols(fit, starts, use_normal=True): method="holm", )[1] + # rename 'stat' to 'z' or 't' + stat_label = "z" if use_normal else "t" + out = 
out.rename(columns={"stat": stat_label}) + return out From 580161500f0d7573627392b722fafa13fa0fd7b9 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 18:08:19 +0100 Subject: [PATCH 604/658] cast to str if dtype is object --- pixels/decorators.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 4ecc48f..4b94ba5 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -12,11 +12,10 @@ from tables import HDF5ExtError try: import zarr - from numcodecs import Blosc, VLenUTF8 + from numcodecs import Blosc except Exception: zarr = None Blosc = None - VLenUTF8 = None try: import xarray as xr @@ -105,28 +104,15 @@ def _df_to_zarr_via_xarray( if compressor is None: compressor = _make_default_compressor() + # compressor & object codec encoding = { - "values": { - "compressor": compressor, - "chunks": tuple(chunking.values()), - } + "values": {"compressor": compressor} } - if ds["values"].dtype == object and VLenUTF8 is not None: - encoding["values"]["object_codec"] = VLenUTF8() - - # Ensure coords are writable (handle object/string coords) - # If VLenUTF8 is available, set encoding for object coords; otherwise cast - # to str + # ensure coords are writable (handle object/string coords): cast to str for cname, coord in ds.coords.items(): if coord.dtype == object: - if VLenUTF8 is not None: - encoding[cname] = { - "object_codec": VLenUTF8(), - "compressor": compressor, - } - else: - ds = ds.assign_coords({cname: coord.astype(str)}) + ds = ds.assign_coords({cname: coord.astype(str)}) # Write if path is not None: @@ -147,8 +133,6 @@ def _df_to_zarr_via_xarray( mode=mode, encoding=encoding, ) - # consolidate requires a path; skipping here since we're inside a shared - # store def _df_from_zarr_via_xarray( From 485bbd87e1eb69e31e186c79b0f900cf850723ef Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 18:08:38 +0100 Subject: [PATCH 605/658] import constants --- pixels/decorators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pixels/decorators.py b/pixels/decorators.py index 4b94ba5..01e398f 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -23,6 +23,7 @@ xr = None from pixels.configs import * +from pixels.constants import * from pixels import ioutils from pixels.error import PixelsError from pixels.units import SelectedUnits From a6c19d2e30829530da44045e2e14e6261a2d6f1d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 18:09:36 +0100 Subject: [PATCH 606/658] if row or index name is None, set it first --- pixels/decorators.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pixels/decorators.py b/pixels/decorators.py index 01e398f..63cd83a 100644 --- a/pixels/decorators.py +++ b/pixels/decorators.py @@ -74,15 +74,21 @@ def _df_to_zarr_via_xarray( if isinstance(df.index, pd.MultiIndex): row_names = _default_names(list(df.index.names), row_prefix) else: - row_names = [df.index.name or f"{row_prefix}0"] + if not df.index.name: + df.index.name = f"{row_prefix}0" + row_names = [df.index.name] if isinstance(df.columns, pd.MultiIndex): col_names = _default_names(list(df.columns.names), col_prefix) else: - col_names = [df.columns.name or f"{col_prefix}0"] - - # Stack ALL column levels to move them into the row index; result index levels = row_names + col_names - series = df.stack(col_names, future_stack=True) # Series with MultiIndex index + if not df.columns.name: + df.columns.name = f"{col_prefix}0" + col_names = 
[df.columns.name] + + # stack ALL column levels to move them into the row index; result index + # levels = row_names + col_names + # Series with MultiIndex index + series = df.stack(col_names, future_stack=True) # Build DataArray (dims are level names of the Series index, in order) da = xr.DataArray.from_series(series).rename("values") From c2a668505858d51ca237e9022fcaddc8b468e771 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 18:11:26 +0100 Subject: [PATCH 607/658] encode both positive and negative response --- pixels/pixels_utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 4f98196..3801660 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2104,6 +2104,12 @@ def test_start_x_zone_interaction_ols(fit): def get_landmark_responsives(pos_fr, units, pos_bin): + """ + use int8 to encode responsiveness: + 0: not responsive + 1: positively responsive + -1: negatively responsive + """ from vision_in_darkness.constants import landmarks # get on & off of landmark and stack them @@ -2129,7 +2135,7 @@ def get_landmark_responsives(pos_fr, units, pos_bin): lm_responsive_bool = np.zeros( (len(units), lms.shape[0]) - ).astype(bool) + ).astype(np.int8) responsives = pd.DataFrame( lm_responsive_bool, index=units, @@ -2138,7 +2144,9 @@ def get_landmark_responsives(pos_fr, units, pos_bin): responsives.index.name = "unit" responsives.columns.name = "landmark" + all_contrasts = {} for l in range(lms.shape[0]): + lm_contrasts = {} # get all data within each chunk chunk = pos_fr.loc[mask[:, l, :]].dropna(axis=1) @@ -2204,6 +2212,10 @@ def get_landmark_responsives(pos_fr, units, pos_bin): ) if (unit_contrasts.coef > 0).all()\ and (unit_contrasts.p_holm < ALPHA).all(): - responsives.loc[unit_id, l] = True + responsives.loc[unit_id, l] = 1 + # negative responsive + if (unit_contrasts.coef < 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id, l] = -1 return responsives From c75dfba7272a9208cd4352a1180897fb7a5a0a23 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 18:11:44 +0100 Subject: [PATCH 608/658] save both contrasts and responsiveness map --- pixels/pixels_utils.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 3801660..c03cf2a 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2210,6 +2210,9 @@ def get_landmark_responsives(pos_fr, units, pos_bin): fit=unit_fit, starts=starts, ) + lm_contrasts[unit_id] = unit_contrasts + + # positive responsive if (unit_contrasts.coef > 0).all()\ and (unit_contrasts.p_holm < ALPHA).all(): responsives.loc[unit_id, l] = 1 @@ -2218,4 +2221,21 @@ def get_landmark_responsives(pos_fr, units, pos_bin): and (unit_contrasts.p_holm < ALPHA).all(): responsives.loc[unit_id, l] = -1 - return responsives + df = pd.concat( + lm_contrasts, + axis=0, + names=["unit", "index"], + ) + all_contrasts[l] = df.droplevel("index") + + contrasts = pd.concat( + all_contrasts, + axis=0, + names=["landmark", "unit"], + ) + contrasts.columns.name = "stat" + contrasts.start = contrasts.start.astype(int) + # so that row index is unique + contrasts = contrasts.set_index(["start", "contrast"], append=True) + + return {"contrasts": contrasts, "responsives": responsives} From 8877042f07c8839d724e27dce9bb684e6686bd7e Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 2 Oct 2025 18:12:02 +0100 Subject: [PATCH 609/658] cache as 
zarr --- pixels/stream.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 53f8c2f..69c83d2 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1280,7 +1280,7 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): return psds - @cacheable + @cacheable(cache_format="zarr") def get_landmark_responsives( self, units, label, event, sigma, end_event, pos_bin, ): @@ -1292,10 +1292,10 @@ def get_landmark_responsives( end_event=end_event, # NOTE: ALWAYS the last arg )["pos_fr"] - responsives = xut.get_landmark_responsives( + output = xut.get_landmark_responsives( pos_fr=pos_fr, units=units[self.stream_id], pos_bin=pos_bin, ) - return responsives + return output From 5ae4e47dfcd3b10cc66a9fcd629bf035badb1d07 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:07:41 +0100 Subject: [PATCH 610/658] no need to pass event since we will loop through them in stream.py --- pixels/behaviours/base.py | 5 +---- pixels/stream.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 6e8188d..b1bf338 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2722,8 +2722,7 @@ def get_spatial_psd( def get_landmark_responsives( - self, label, event, end_event=None, sigma=None, units=None, - pos_bin=None, + self, label, sigma=None, units=None, pos_bin=None, ): output = {} streams = self.files["pixels"] @@ -2742,9 +2741,7 @@ def get_landmark_responsives( output[stream_id] = stream.get_landmark_responsives( units=units, label=label, - event=event, sigma=sigma, - end_event=end_event, pos_bin=pos_bin, ) diff --git a/pixels/stream.py b/pixels/stream.py index 69c83d2..cf1c8aa 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1281,9 +1281,6 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): @cacheable(cache_format="zarr") - def get_landmark_responsives( - self, units, label, event, sigma, end_event, pos_bin, - ): pos_fr = self.get_positional_data( units=units, # NOTE: ALWAYS the first arg label=label, @@ -1291,6 +1288,7 @@ def get_landmark_responsives( sigma=sigma, end_event=end_event, # NOTE: ALWAYS the last arg )["pos_fr"] + def get_landmark_responsives(self, units, label, sigma, pos_bin): output = xut.get_landmark_responsives( pos_fr=pos_fr, From ad877fcd899bb8aba5d76ea18d6d4b2805d7a1bd Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:08:48 +0100 Subject: [PATCH 611/658] only copy .bin when necessary --- pixels/stream.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index cf1c8aa..4f2d874 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -47,8 +47,8 @@ def __repr__(self): return f"" - def load_raw_ap(self): - paths = [self.session.find_file(path, copy=True) for path in self.files["ap_raw"]] + def load_raw_ap(self, copy=False): + paths = [self.session.find_file(path, copy=copy) for path in self.files["ap_raw"]] self.files["si_rec"] = xut.load_raw(paths, self.stream_id) return self.files["si_rec"] @@ -987,7 +987,7 @@ def get_positional_data( def preprocess_raw(self): # load raw ap - raw_rec = self.load_raw_ap() + raw_rec = self.load_raw_ap(copy=True) # load brain surface depths depth_info = file_utils.load_yaml( @@ -1015,7 +1015,7 @@ def extract_bands(self, freqs, preprocess=True): self.preprocess_raw() rec = self.files["preprocessed"] else: - rec = self.load_raw_ap() + rec = self.load_raw_ap(copy=True) 
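
The @cacheable(cache_format="zarr") form above uses the decorator-factory pattern: calling cacheable with arguments returns the decorator that actually wraps the method. A stripped-down sketch of that shape, in which pickle stands in for the real HDF5/zarr readers and writers and cache_dir is an assumed attribute:

import functools
import pandas as pd

def cacheable(cache_format="hdf5"):
    def decorator(method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            path = self.cache_dir / f"{method.__name__}.{cache_format}"
            if path.exists():
                return pd.read_pickle(path)     # stand-in for the real reader
            result = method(self, *args, **kwargs)
            pd.to_pickle(result, path)          # stand-in for the real writer
            return result
        return wrapper
    return decorator

The project's cacheable also supports the bare @cacheable form and builds the cache name from the call arguments; this sketch only shows the called form.
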
if freqs == None: bands = freq_bands From 78f5fee061112c2a4b9e4bd2cb15cbc98b83eb01 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:09:51 +0100 Subject: [PATCH 612/658] simplify --- pixels/stream.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 4f2d874..d763ea7 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -66,11 +66,11 @@ def _map_trials(self, label, event, end_event=None): timestamps = action_labels["timestamps"] # select frames of wanted trial type - trials = np.where(np.bitwise_and(outcomes, label))[0] + trials = np.flatnonzero(outcomes & label) # map starts by event - starts = np.where(np.bitwise_and(events, event))[0] + starts = np.flatnonzero(events & event) # map starts by end event - ends = np.where(np.bitwise_and(events, end_event))[0] + ends = np.flatnonzero(events & end_event) # only take starts and ends from selected trials selected_starts = trials[np.isin(trials, starts)] From 42270e7de682d54934267cd12085b00db2036de1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:10:46 +0100 Subject: [PATCH 613/658] when aligning to a landmark event in dark, make sure it is in dark --- pixels/stream.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index d763ea7..b5ea271 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -72,6 +72,15 @@ def _map_trials(self, label, event, end_event=None): # map starts by end event ends = np.flatnonzero(events & end_event) + if ("dark" in label.name) and ("landmark" in event.name): + # get dark onset and offset edges + dark_on = ((events & event.dark_on) != 0).astype(np.int8) + dark_off = ((events & event.dark_off) != 0).astype(np.int8) + # make in dark boolean + in_dark = np.cumsum(dark_on - dark_off) > 0 + # get starts only when also in dark + starts = np.flatnonzero(((events & event) != 0) & in_dark) + # only take starts and ends from selected trials selected_starts = trials[np.isin(trials, starts)] selected_ends = trials[np.isin(trials, ends)] From eb9dd5d74dffc421f8c6053ea24f54d70bbcf38d Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:11:19 +0100 Subject: [PATCH 614/658] always check the common trial ids of starts and ends and correct both --- pixels/stream.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index b5ea271..a8a88a6 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -85,23 +85,21 @@ def _map_trials(self, label, event, end_event=None): selected_starts = trials[np.isin(trials, starts)] selected_ends = trials[np.isin(trials, ends)] - # make sure trials have both starts and ends + # make sure trials have both starts and ends, some trials ended before + # the end_event, and some dark trials are not in dark at start event start_ids = synched_vr.iloc[selected_starts].trial_count.unique() end_ids = synched_vr.iloc[selected_ends].trial_count.unique() - if len(start_ids) > len(end_ids): - selected_starts = selected_starts[np.isin(start_ids, end_ids)] - elif len(start_ids) < len(end_ids): - selected_ends = selected_ends[np.isin(end_ids, start_ids)] - raise PixelsError("\n> Why would we have more ends than starts?") + common_ids = np.intersect1d(start_ids, end_ids) + if len(start_ids) != len(end_ids): + selected_starts = selected_starts[np.isin(start_ids, common_ids)] + selected_ends = selected_ends[np.isin(end_ids, common_ids)] # get timestamps start_t = timestamps[selected_starts] end_t = 
timestamps[selected_ends] # use original trial ids as trial index - trial_ids = pd.Index( - synched_vr.iloc[selected_starts].trial_count.unique() - ) + trial_ids = pd.Index(common_ids) return trials, events, selected_starts, start_t, end_t, trial_ids From 43f283530b5cdfd35622b1990e21dcfff26d5cd2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:16:48 +0100 Subject: [PATCH 615/658] move landmark loop to stream.py --- pixels/pixels_utils.py | 200 ++++++++++++++++++----------------------- pixels/stream.py | 70 ++++++++++++--- 2 files changed, 145 insertions(+), 125 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c03cf2a..7d221e0 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2103,139 +2103,113 @@ def test_start_x_zone_interaction_ols(fit): # <<< landmark responsive helpers <<< -def get_landmark_responsives(pos_fr, units, pos_bin): +def get_landmark_responsives(pos_fr, units, ons, offs): """ use int8 to encode responsiveness: 0: not responsive 1: positively responsive -1: negatively responsive """ - from vision_in_darkness.constants import landmarks - - # get on & off of landmark and stack them - lms = landmarks[1:-2].reshape((-1, 2)) - - # get pre & post wall - wall_on = landmarks[:-1][::2] + pos_bin - wall_off = landmarks[:-1][1::2] - pos_bin - pre_walls = np.column_stack([wall_on[:-1], wall_off[:-1]]) - post_walls = np.column_stack([wall_on[1:], wall_off[1:]]) - - # group pre wall, landmarks, and post wall together - chunks = np.stack([pre_walls, lms, post_walls], axis=1) - - ons = chunks[..., 0] - offs = chunks[..., 1] + units = units.flat() # get all positions positions = pos_fr.index.to_numpy() # build mask for all positions - mask = (positions[:, None, None] >= ons[None, :, :])\ - & (positions[:, None, None] <= offs[None, :, :]) + position_mask = (positions[:, None] >= ons)\ + & (positions[:, None] <= offs) + + # get pre wall and trial mask + pre_wall = pos_fr.loc[ + position_mask[:, 0], : + ].dropna(axis=1, how="any") + trials_pre_wall = pre_wall.columns.get_level_values( + "trial" + ).unique() + trial_mask = pos_fr.columns.get_level_values( + "trial" + ).isin(trials_pre_wall) + + # get mean & std of walls and landmark + landmark = pos_fr.loc[ + position_mask[:, 1], trial_mask + ].dropna(axis=1, how="any") + post_wall = pos_fr.loc[ + position_mask[:, 2], trial_mask + ].dropna(axis=1, how="any") + + pre_wall_mean = pre_wall.mean(axis=0) + pre_wall_std = pre_wall.std(axis=0) + + landmark_mean = landmark.mean(axis=0) + landmark_std = landmark.std(axis=0) + + post_wall_mean = post_wall.mean(axis=0) + post_wall_std = post_wall.std(axis=0) + + # aggregate + agg = pd.concat([ + to_df(pre_wall_mean, pre_wall_std, "pre_wall"), + to_df(landmark_mean, landmark_std, "landmark"), + to_df(post_wall_mean, post_wall_std, "post_wall"), + ], + ignore_index=True, + ) + agg["zone"] = pd.Categorical( + agg["zone"], + categories=["pre_wall", "landmark", "post_wall"], + ordered=True, + ) + + # get all starting positions + starts = agg.start.unique() + + # create model formula + min_start = min(starts) + simple_model = ( + "mean ~ C(zone, Treatment(reference='pre_wall'))" + ) + full_model = ( + f"""mean + ~ C(start, Treatment(reference={min_start!r})) + * C(zone, Treatment(reference='pre_wall'))""" + ) - lm_responsive_bool = np.zeros( - (len(units), lms.shape[0]) - ).astype(np.int8) - responsives = pd.DataFrame( + lm_responsive_bool = np.zeros(len(units)).astype(np.int8) + responsives = pd.Series( lm_responsive_bool, index=units, - 
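
The position masks used in this refactor compare every binned position against each zone's on and off bound in one broadcast. A toy version with two zones:

import numpy as np

positions = np.arange(60, 131, 10)      # binned track positions (cm), toy values
ons = np.array([70, 110])               # zone onsets (e.g. pre-wall, landmark)
offs = np.array([100, 130])             # matching offsets

mask = (positions[:, None] >= ons) & (positions[:, None] <= offs)
# mask[i, j] is True where positions[i] lies inside zone j
print(mask.shape)                        # (8, 2)
print(positions[mask[:, 1]])             # positions falling in the second zone
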
columns=np.arange(lms.shape[0]), ) responsives.index.name = "unit" - responsives.columns.name = "landmark" - - all_contrasts = {} - for l in range(lms.shape[0]): - lm_contrasts = {} - # get all data within each chunk - chunk = pos_fr.loc[mask[:, l, :]].dropna(axis=1) - - # build mask chunk positions - chunk_pos = chunk.index.values - chunk_mask = ( - chunk_pos[:, None] >= ons[l, :] - ) & (chunk_pos[:, None] <= offs[l, :]) - - # get mean & std of walls and landmark - pre_wall = chunk.loc[chunk_mask[:, 0]].dropna(axis=1) - pre_wall_mean = pre_wall.mean(axis=0) - pre_wall_std = pre_wall.std(axis=0) - - landmark = chunk.loc[chunk_mask[:, 1]].dropna(axis=1) - landmark_mean = landmark.mean(axis=0) - landmark_std = landmark.std(axis=0) - - post_wall = chunk.loc[chunk_mask[:, 2]].dropna(axis=1) - post_wall_mean = post_wall.mean(axis=0) - post_wall_std = post_wall.std(axis=0) - - # aggregate - agg = pd.concat([ - to_df(pre_wall_mean, pre_wall_std, "pre_wall"), - to_df(landmark_mean, landmark_std, "landmark"), - to_df(post_wall_mean, post_wall_std, "post_wall"), - ], - ignore_index=True, - ) - agg["zone"] = pd.Categorical( - agg["zone"], - categories=["pre_wall", "landmark", "post_wall"], - ordered=True, - ) - - # get all starting positions - starts = agg.start.unique() - # create model formula - min_start = min(starts) - simple_model = ( - "mean ~ C(zone, Treatment(reference='pre_wall'))" + lm_contrasts = {} + for unit_id in units: + unit_fit = fit_per_unit_ols( + df=agg, + formula=full_model, + #formula=simple_model, + unit_id=unit_id, ) - full_model = ( - f"""mean - ~ C(start, Treatment(reference={min_start})) - * C(zone, Treatment(reference='pre_wall'))""" + #print(unit_fit.summary()) + # step 4: check contrast at each start + unit_contrasts = start_contrasts_ols( + fit=unit_fit, + starts=starts, ) + lm_contrasts[unit_id] = unit_contrasts - for unit_id in units: - unit_fit = fit_per_unit_ols( - df=agg, - formula=full_model, - #formula=simple_model, - unit_id=unit_id, - ) - #print(unit_fit.summary()) - # step 4: check contrast at each start - unit_contrasts = start_contrasts_ols( - fit=unit_fit, - starts=starts, - ) - lm_contrasts[unit_id] = unit_contrasts - - # positive responsive - if (unit_contrasts.coef > 0).all()\ - and (unit_contrasts.p_holm < ALPHA).all(): - responsives.loc[unit_id, l] = 1 - # negative responsive - if (unit_contrasts.coef < 0).all()\ - and (unit_contrasts.p_holm < ALPHA).all(): - responsives.loc[unit_id, l] = -1 - - df = pd.concat( - lm_contrasts, - axis=0, - names=["unit", "index"], - ) - all_contrasts[l] = df.droplevel("index") + # positive responsive + if (unit_contrasts.coef > 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id] = 1 + # negative responsive + if (unit_contrasts.coef < 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id] = -1 contrasts = pd.concat( - all_contrasts, + lm_contrasts, axis=0, - names=["landmark", "unit"], - ) - contrasts.columns.name = "stat" - contrasts.start = contrasts.start.astype(int) - # so that row index is unique - contrasts = contrasts.set_index(["start", "contrast"], append=True) + names=["unit", "index"], + ).droplevel("index") - return {"contrasts": contrasts, "responsives": responsives} + return contrasts, responsives diff --git a/pixels/stream.py b/pixels/stream.py index a8a88a6..6e713c1 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1288,19 +1288,65 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): @cacheable(cache_format="zarr") - 
pos_fr = self.get_positional_data( - units=units, # NOTE: ALWAYS the first arg - label=label, - event=event, - sigma=sigma, - end_event=end_event, # NOTE: ALWAYS the last arg - )["pos_fr"] def get_landmark_responsives(self, units, label, sigma, pos_bin): - output = xut.get_landmark_responsives( - pos_fr=pos_fr, - units=units[self.stream_id], - pos_bin=pos_bin, + from vision_in_darkness.constants import landmarks + from pixels.behaviours.virtual_reality import Events + # get landmark events + landmark_events = [ + value for key, value in Events.__dict__.items() + if not (key.startswith("__") or key.startswith("_") + or callable(value) + or isinstance(value, (classmethod, staticmethod))) + and "landmark" in key + ] + # exclude the black wall and the last landmark + NUM_LANDMARKS = len(landmark_events[2:-2]) // 2 + # get start and end events + end_events = landmark_events[1:-1][3::2] + + # get on & off of landmark and stack them + lms = landmarks[1:-2].reshape((-1, 2)) + + # get pre & post wall + wall_on = landmarks[:-1][::2] + pos_bin + wall_off = landmarks[:-1][1::2] - pos_bin + pre_walls = np.column_stack([wall_on[:-1], wall_off[:-1]]) + post_walls = np.column_stack([wall_on[1:], wall_off[1:]]) + + # group pre wall, landmarks, and post wall together + chunks = np.stack([pre_walls, lms, post_walls], axis=1) + + ons = chunks[..., 0] + offs = chunks[..., 1] + + all_contrasts = {} + resps = {} + for l, _ in enumerate(lms): + pos_fr = self.get_positional_data( + units=units, # NOTE: ALWAYS the first arg + label=label, + sigma=sigma, + end_event=end_events[l], # NOTE: ALWAYS the last arg + )["pos_fr"] + + all_contrasts[l], resps[l] = xut.get_landmark_responsives( + pos_fr=pos_fr, + units=units, + ons=ons[l, :], + offs=offs[l, :], + ) + + responsives = pd.concat(resps, axis=1, names="landmark") + + contrasts = pd.concat( + all_contrasts, + axis=0, + names=["landmark", "unit"], ) + contrasts.columns.name = "metrics" + contrasts.start = contrasts.start.astype(int) + # so that row index is unique + contrasts = contrasts.set_index(["start", "contrast"], append=True) - return output + return {"contrasts": contrasts, "responsives": responsives} From 2ce6c9a2be10c54cb4c1b14f0044b4ee753d7caa Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:34:53 +0100 Subject: [PATCH 616/658] loop through start events too --- pixels/stream.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 6e713c1..6d8b1d8 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1303,6 +1303,7 @@ def get_landmark_responsives(self, units, label, sigma, pos_bin): # exclude the black wall and the last landmark NUM_LANDMARKS = len(landmark_events[2:-2]) // 2 # get start and end events + start_events = landmark_events[1:-1][0::2][:-1] end_events = landmark_events[1:-1][3::2] # get on & off of landmark and stack them @@ -1326,6 +1327,7 @@ def get_landmark_responsives(self, units, label, sigma, pos_bin): pos_fr = self.get_positional_data( units=units, # NOTE: ALWAYS the first arg label=label, + event=start_events[l], sigma=sigma, end_event=end_events[l], # NOTE: ALWAYS the last arg )["pos_fr"] From bd2a01834fb39d03c7d599e250c4c8536d4220e6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Thu, 9 Oct 2025 17:35:20 +0100 Subject: [PATCH 617/658] drop all nan columns --- pixels/stream.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index 6d8b1d8..5a78122 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -164,6 +164,8 @@ def 
_get_vr_positions(self, label, event, end_event): positions.columns = cols_with_starts positions = positions.sort_index(axis=1, ascending=[False, True]) positions.index.name = "time" + # drop trials with all nan + positions = positions.dropna(axis=1, how="all") return positions From 95a9c16741788cf28d4409c42ee70b31b32afefe Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:50:16 +0100 Subject: [PATCH 618/658] no need to pass pos_bin cuz we now pre-define wall events --- pixels/behaviours/base.py | 5 +---- pixels/behaviours/virtual_reality.py | 24 +++++++++++++++++++----- pixels/stream.py | 4 +--- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index b1bf338..e76cb2b 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2721,9 +2721,7 @@ def get_spatial_psd( return output - def get_landmark_responsives( - self, label, sigma=None, units=None, pos_bin=None, - ): + def get_landmark_responsives(self, label, sigma=None, units=None): output = {} streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): @@ -2742,7 +2740,6 @@ def get_landmark_responsives( units=units, label=label, sigma=sigma, - pos_bin=pos_bin, ) return output diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index 8294162..ed634c6 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -93,26 +93,38 @@ class Events(IntFlag): # positional events pre_dark_end = auto()# 50 cm - # TODO jun 4 2025: - # how to mark wall? - wall = auto()# in between landmarks - # black wall landmark0_on = auto()# wherever trial starts, before 60cm landmark0_off = auto()# 60 cm + # NOTE: wall between landmarks always remove 10cm adjacent + wall1_on = auto()# 70cm + wall1_off = auto()# 100cm + landmark1_on = auto()# 110 cm landmark1_off = auto()# 130 cm + wall2_on = auto()# 140cm + wall2_off = auto()# 180cm + landmark2_on = auto()# 190 cm landmark2_off = auto()# 210 cm + wall3_on = auto()# 220cm + wall3_off = auto()# 260cm + landmark3_on = auto()# 270 cm landmark3_off = auto()# 290 cm + wall4_on = auto()# 300cm + wall4_off = auto()# 340cm + landmark4_on = auto()# 350 cm landmark4_off = auto()# 370 cm + wall5_on = auto()# 380cm + wall5_off = auto()# 420cm + landmark5_on = auto()# 430 cm landmark5_off = auto()# 450 cm @@ -438,8 +450,10 @@ def _first_post_mark(group_df, check_marks): masks[Events.landmark0_off] = landmark0_off # <<< landmark 0 black wall <<< - # >>> landmarks 1 to 5 >>> + # >>> landmarks and wall 1 to 5 >>> landmarks = session.landmarks[1:] + # get walls, excluding adjacent 10cm to landmark + walls = session.mid_walls for l, landmark in enumerate(landmarks): if l % 2 != 0: diff --git a/pixels/stream.py b/pixels/stream.py index 5a78122..7e52337 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1290,7 +1290,7 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): @cacheable(cache_format="zarr") - def get_landmark_responsives(self, units, label, sigma, pos_bin): + def get_landmark_responsives(self, units, label, sigma): from vision_in_darkness.constants import landmarks from pixels.behaviours.virtual_reality import Events @@ -1312,8 +1312,6 @@ def get_landmark_responsives(self, units, label, sigma, pos_bin): lms = landmarks[1:-2].reshape((-1, 2)) # get pre & post wall - wall_on = landmarks[:-1][::2] + pos_bin - wall_off = landmarks[:-1][1::2] - pos_bin pre_walls = 
np.column_stack([wall_on[:-1], wall_off[:-1]]) post_walls = np.column_stack([wall_on[1:], wall_off[1:]]) From 1940f848f682e37a883ba2937081eacc4951b71c Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:50:53 +0100 Subject: [PATCH 619/658] formatting --- pixels/stream.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 7e52337..372e7c1 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -48,7 +48,10 @@ def __repr__(self): def load_raw_ap(self, copy=False): - paths = [self.session.find_file(path, copy=copy) for path in self.files["ap_raw"]] + paths = [ + self.session.find_file(path, copy=copy) + for path in self.files["ap_raw"] + ] self.files["si_rec"] = xut.load_raw(paths, self.stream_id) return self.files["si_rec"] From 81b4335964371a73ca1ca0e9132308a429181115 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:51:09 +0100 Subject: [PATCH 620/658] in case aligning to wall --- pixels/stream.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 372e7c1..440a0c4 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -75,7 +75,8 @@ def _map_trials(self, label, event, end_event=None): # map starts by end event ends = np.flatnonzero(events & end_event) - if ("dark" in label.name) and ("landmark" in event.name): + if ("dark" in label.name) and\ + any(name in event.name for name in ["landmark" or "wall"]): # get dark onset and offset edges dark_on = ((events & event.dark_on) != 0).astype(np.int8) dark_off = ((events & event.dark_off) != 0).astype(np.int8) From 3c650898cf3b5fbb894ea5f2098cac3981cf754e Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:51:32 +0100 Subject: [PATCH 621/658] show events of alignment too --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 440a0c4..5f2eacc 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -333,7 +333,7 @@ def align_trials(self, units, data, label, event, sigma, end_event): if "spike_trial" in data: logging.info( f"\n> Aligning spike times and spike rate of {units} units to " - f"<{label.name}> trials." + f"<{label.name}> trials, from {event.name} to {end_event.name}." 
) return self._get_aligned_trials( label, event, units=units, sigma=sigma, end_event=end_event, From 8690d9b56991e684516b32d2220c0f0795f6b956 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:52:01 +0100 Subject: [PATCH 622/658] get spike data only when it's necessary --- pixels/stream.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 5f2eacc..c953e1c 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -661,16 +661,11 @@ def sync_vr(self, vr_session): else: self._sync_vr(vr_session) + """ return None def _sync_vr(self, vr_session): - # get spike data - spike_data = self.session.find_file( - name=self.files["ap_raw"][self.stream_num], - copy=True, - ) - # get synchronised vr path synched_vr_path = vr_session.cache_dir + "synched/" +\ vr_session.name + "_vr_synched.h5" @@ -679,6 +674,12 @@ def _sync_vr(self, vr_session): synched_vr = file_utils.read_hdf5(synched_vr_path) logging.info("\n> synchronised vr loaded") except: + # get spike data + spike_data = self.session.find_file( + name=self.files["ap_raw"][self.stream_num], + copy=True, + ) + # get sync pulses sync_map = ioutils.read_bin(spike_data, 385, 384) syncs = signal.binarise(sync_map) From 9cf79ed6a08c2ee2e019cbcdf63d974b9317e725 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:53:15 +0100 Subject: [PATCH 623/658] no need to dropna cuz there will not be any --- pixels/pixels_utils.py | 17 +++++------------ pixels/stream.py | 2 -- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 7d221e0..fdd4ae9 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2119,23 +2119,16 @@ def get_landmark_responsives(pos_fr, units, ons, offs): & (positions[:, None] <= offs) # get pre wall and trial mask - pre_wall = pos_fr.loc[ - position_mask[:, 0], : - ].dropna(axis=1, how="any") + pre_wall = pos_fr.loc[position_mask[:, 0], :] trials_pre_wall = pre_wall.columns.get_level_values( "trial" ).unique() - trial_mask = pos_fr.columns.get_level_values( - "trial" - ).isin(trials_pre_wall) # get mean & std of walls and landmark - landmark = pos_fr.loc[ - position_mask[:, 1], trial_mask - ].dropna(axis=1, how="any") - post_wall = pos_fr.loc[ - position_mask[:, 2], trial_mask - ].dropna(axis=1, how="any") + landmark = pos_fr.loc[position_mask[:, 1], :] + assert (landmark.columns.get_level_values("trial").unique() ==\ + trials_pre_wall).all() + post_wall = pos_fr.loc[position_mask[:, 2], :] pre_wall_mean = pre_wall.mean(axis=0) pre_wall_std = pre_wall.std(axis=0) diff --git a/pixels/stream.py b/pixels/stream.py index c953e1c..331c536 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -168,8 +168,6 @@ def _get_vr_positions(self, label, event, end_event): positions.columns = cols_with_starts positions = positions.sort_index(axis=1, ascending=[False, True]) positions.index.name = "time" - # drop trials with all nan - positions = positions.dropna(axis=1, how="all") return positions From fe420d5ec7f64199702954894f606f2ac1c0e878 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:53:39 +0100 Subject: [PATCH 624/658] exclude equal cuz we used np.ceil to get position --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index fdd4ae9..7ceab4a 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2116,7 +2116,7 @@ def get_landmark_responsives(pos_fr, units, 
ons, offs): positions = pos_fr.index.to_numpy() # build mask for all positions position_mask = (positions[:, None] >= ons)\ - & (positions[:, None] <= offs) + & (positions[:, None] < offs) # get pre wall and trial mask pre_wall = pos_fr.loc[position_mask[:, 0], :] From 8d49251a1495fd21510bd7b526a6d05cecbe2dea Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:54:09 +0100 Subject: [PATCH 625/658] get default position bin size --- pixels/constants.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/constants.py b/pixels/constants.py index 7ba414c..004b2b6 100644 --- a/pixels/constants.py +++ b/pixels/constants.py @@ -30,3 +30,7 @@ BIG_CHUNKS = 1024 ALPHA = 0.05 + +# position bin sizes +POSITION_BIN = 1 # cm +BIG_POSITION_BIN = 10 # cm From cd494020d96bb6e7e38c87ce29ec7ea5afb2357e Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:54:33 +0100 Subject: [PATCH 626/658] increase dtype cuz there are more events now --- pixels/behaviours/virtual_reality.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index ed634c6..d9b35b6 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -145,8 +145,8 @@ class Events(IntFlag): class LabeledEvents(NamedTuple): """Return type: timestamps + bitfields for outcome & events.""" timestamps: np.ndarray # shape (N,) - outcome: np.ndarray # shape (N,) dtype uint32 - events: np.ndarray # shape (N,) dtype uint32 + outcome: np.ndarray # shape (N,) dtype uint64 + events: np.ndarray # shape (N,) dtype uint64 class WorldMasks(NamedTuple): @@ -206,8 +206,8 @@ def _extract_action_labels( at i """ N = len(data) - events_arr = np.zeros(N, dtype=np.uint32) - outcomes_arr = np.zeros(N, dtype=np.uint32) + events_arr = np.zeros(N, dtype=np.uint64) + outcomes_arr = np.zeros(N, dtype=np.uint64) # world index based events world_index_based_events = self._world_event_indices(data) From 53ae11a1c11da710d9b9d0d8f52a7117fd472416 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:55:06 +0100 Subject: [PATCH 627/658] make sure to convert event flag into uint64 to avoid ambiguity --- pixels/behaviours/virtual_reality.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index d9b35b6..e59fcd5 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -213,7 +213,7 @@ def _extract_action_labels( world_index_based_events = self._world_event_indices(data) for event, idx in world_index_based_events.items(): mask = self._get_index(data, idx.to_numpy()) - self._stamp_mask(events_arr, mask, event) + self._stamp_mask(events_arr, mask, np.uint64(event.value)) # get dark onset times for pre_dark_len labels dark_on_t = world_index_based_events[Events.dark_on] @@ -226,10 +226,14 @@ def _extract_action_labels( ) for event, idx in pos_based_events.items(): mask = self._get_index(data, idx.to_numpy()) - self._stamp_mask(events_arr, mask, event) + self._stamp_mask(events_arr, mask, np.uint64(event.value)) # sensors: lick rising‐edge - self._stamp_rising(events_arr, data.lick_detect.values, Events.licked) + self._stamp_rising( + events_arr, + data.lick_detect.values, + np.uint64(Events.licked), + ) # map trial outcomes outcome_map = self._build_outcome_map() @@ -249,7 +253,7 @@ def _extract_action_labels( if len(valve_events) > 0: for v_event, v_idx in 
valve_events.items(): valve_mask = self._get_index(data, v_idx) - self._stamp_mask(events_arr, valve_mask, v_event) + self._stamp_mask(events_arr, valve_mask, np.uint64(v_event)) # return typed arrays return LabeledEvents( From a72eed16d9ed9ad63055f428522bcbcc9b0cee87 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:56:02 +0100 Subject: [PATCH 628/658] make sure to get the exact position for the positional event --- pixels/behaviours/virtual_reality.py | 67 +++++++++++++++++++++------- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/pixels/behaviours/virtual_reality.py b/pixels/behaviours/virtual_reality.py index e59fcd5..45e8501 100644 --- a/pixels/behaviours/virtual_reality.py +++ b/pixels/behaviours/virtual_reality.py @@ -465,40 +465,75 @@ def _first_post_mark(group_df, check_marks): landmark_idx = l // 2 + 1 + # NOTE: both landmark and wall boolean use < not <= because + # in our vr data we only get the exact value (e.g., 140.0) at + # starting position. # even idx on, odd idx off on_landmark = landmark - off_landmark = landmarks[l + 1] - - in_landmark = ( + on_landmarks = ( (df.position_in_tunnel >= on_landmark) & - (df.position_in_tunnel <= off_landmark) + (df.position_in_tunnel < on_landmark + 1) ) - - landmark_on = df[in_landmark].groupby("trial_count").apply( + landmark_on = df[on_landmarks].groupby("trial_count").apply( self._first_index ) - landmark_off = df[in_landmark].groupby("trial_count").apply( + + off_landmark = landmarks[l + 1] + off_landmarks = ( + (df.position_in_tunnel > off_landmark - 1) & + (df.position_in_tunnel < off_landmark) + ) + landmark_off = df[off_landmarks].groupby("trial_count").apply( self._last_index ) - masks[getattr(Events, f"landmark{landmark_idx}_on")] = landmark_on masks[getattr(Events, f"landmark{landmark_idx}_off")] = landmark_off - # <<< landmarks 1 to 5 <<< + + # even idx on, odd idx off + on_wall = walls[l] + on_walls = ( + (df.position_in_tunnel >= on_wall) & + (df.position_in_tunnel < on_wall + 1) + ) + wall_on = df[on_walls].groupby("trial_count").apply( + self._first_index + ) + + off_wall = walls[l + 1] + off_walls = ( + (df.position_in_tunnel > off_wall - 1) & + (df.position_in_tunnel < off_wall) + ) + wall_off = df[off_walls].groupby("trial_count").apply( + self._last_index + ) + + masks[getattr(Events, f"wall{landmark_idx}_on")] = wall_on + masks[getattr(Events, f"wall{landmark_idx}_off")] = wall_off + # <<< landmarks and wall 1 to 5 <<< # >>> reward zone >>> - in_zone = ( + zone_ons = ( df.position_in_tunnel >= session.reward_zone_start ) & ( - df.position_in_tunnel <= session.reward_zone_end + df.position_in_tunnel < session.reward_zone_start + 1 ) - in_zone_trials = df[in_zone].groupby("trial_count") - # first frame in reward zone - zone_on_t = in_zone_trials.apply(self._first_index) - masks[Events.reward_zone_on] = zone_on_t + zone_on_t = df[zone_ons].groupby("trial_count").apply( + self._first_index + ) + zone_offs = ( + df.position_in_tunnel > session.reward_zone_end - 1 + ) & ( + df.position_in_tunnel < session.reward_zone_end + ) # last frame in reward zone - zone_off_t = in_zone_trials.apply(self._last_index) + zone_off_t = df[zone_offs].groupby("trial_count").apply( + self._last_index + ) + + masks[Events.reward_zone_on] = zone_on_t masks[Events.reward_zone_off] = zone_off_t # <<< reward zone <<< From 54eee6184e222db8ebeb6afc337af4da5f683990 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:57:09 +0100 Subject: [PATCH 629/658] align to wall on/off events --- pixels/stream.py 
| 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 331c536..c80626a 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1295,28 +1295,27 @@ def get_chance_positional_psd(self, units, label, event, sigma, end_event): @cacheable(cache_format="zarr") def get_landmark_responsives(self, units, label, sigma): - from vision_in_darkness.constants import landmarks + from vision_in_darkness.constants import landmarks, mid_walls from pixels.behaviours.virtual_reality import Events - # get landmark events - landmark_events = [ + + wall_events = np.array([ value for key, value in Events.__dict__.items() if not (key.startswith("__") or key.startswith("_") or callable(value) or isinstance(value, (classmethod, staticmethod))) - and "landmark" in key - ] - # exclude the black wall and the last landmark - NUM_LANDMARKS = len(landmark_events[2:-2]) // 2 - # get start and end events - start_events = landmark_events[1:-1][0::2][:-1] - end_events = landmark_events[1:-1][3::2] + and "wall" in key], + dtype=object, + ).reshape(-1, 2) - # get on & off of landmark and stack them - lms = landmarks[1:-2].reshape((-1, 2)) + # get start and end events, excluding the last landmark + start_events = wall_events[:, 0][:-1] + end_events = wall_events[:, 1][1:] - # get pre & post wall - pre_walls = np.column_stack([wall_on[:-1], wall_off[:-1]]) - post_walls = np.column_stack([wall_on[1:], wall_off[1:]]) + # get on & off of landmark and walls and stack them + lms = landmarks[1:-2].reshape((-1, 2)) + walls = mid_walls.reshape((-1, 2)) + pre_walls = walls[:-1, :] + post_walls = walls[1:, :] # group pre wall, landmarks, and post wall together chunks = np.stack([pre_walls, lms, post_walls], axis=1) From eb167ce0687783365a0660e5090225b7eadd9cd2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:57:21 +0100 Subject: [PATCH 630/658] use 1 index for landmarks to be consistent with virtual_reality.py cuz black wall is landmark 0 --- pixels/stream.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index c80626a..19a0505 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1312,6 +1312,7 @@ def get_landmark_responsives(self, units, label, sigma): end_events = wall_events[:, 1][1:] # get on & off of landmark and walls and stack them + landmark_names = np.arange(1, len(start_events)+1) lms = landmarks[1:-2].reshape((-1, 2)) walls = mid_walls.reshape((-1, 2)) pre_walls = walls[:-1, :] @@ -1334,7 +1335,8 @@ def get_landmark_responsives(self, units, label, sigma): end_event=end_events[l], # NOTE: ALWAYS the last arg )["pos_fr"] - all_contrasts[l], resps[l] = xut.get_landmark_responsives( + lm = landmark_names[l] + all_contrasts[lm], resps[lm] = xut.get_landmark_responsives( pos_fr=pos_fr, units=units, ons=ons[l, :], From f6db6326e8ebc7acea4b1d78f0a910914ab0e1cd Mon Sep 17 00:00:00 2001 From: amz_office Date: Sat, 11 Oct 2025 21:59:11 +0100 Subject: [PATCH 631/658] stuff --- pixels/stream.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 19a0505..11055dd 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -659,7 +659,6 @@ def sync_vr(self, vr_session): else: self._sync_vr(vr_session) - """ return None From be3c024927a28cc210c57662f8e58c36ff3e076e Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 14 Oct 2025 12:29:16 +0100 Subject: [PATCH 632/658] remove redundant --- pixels/pixels_utils.py | 3 +-- 1 file changed, 1 
insertion(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 7ceab4a..d2f26b5 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2182,8 +2182,7 @@ def get_landmark_responsives(pos_fr, units, ons, offs): #formula=simple_model, unit_id=unit_id, ) - #print(unit_fit.summary()) - # step 4: check contrast at each start + # check contrast at each start unit_contrasts = start_contrasts_ols( fit=unit_fit, starts=starts, From 77a4613f353972651cbc1b71ede0c85e27f2a686 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 14 Oct 2025 12:29:37 +0100 Subject: [PATCH 633/658] skip if number of starts less than 2 --- pixels/pixels_utils.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d2f26b5..b3939ca 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2189,14 +2189,20 @@ def get_landmark_responsives(pos_fr, units, ons, offs): ) lm_contrasts[unit_id] = unit_contrasts - # positive responsive - if (unit_contrasts.coef > 0).all()\ - and (unit_contrasts.p_holm < ALPHA).all(): - responsives.loc[unit_id] = 1 - # negative responsive - if (unit_contrasts.coef < 0).all()\ - and (unit_contrasts.p_holm < ALPHA).all(): - responsives.loc[unit_id] = -1 + if len(starts) < 2: + logging.info( + f"\n> Skip testing this landmark cuz only start {starts[0]} " + "covers it." + ) + else: + # positive responsive + if (unit_contrasts.coef > 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id] = 1 + # negative responsive + if (unit_contrasts.coef < 0).all()\ + and (unit_contrasts.p_holm < ALPHA).all(): + responsives.loc[unit_id] = -1 contrasts = pd.concat( lm_contrasts, From b29ae23b72e812ed9da53abbdcb1ee10d02b5997 Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 14 Oct 2025 12:29:57 +0100 Subject: [PATCH 634/658] fix typo --- pixels/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/stream.py b/pixels/stream.py index 11055dd..df7415e 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -76,7 +76,7 @@ def _map_trials(self, label, event, end_event=None): ends = np.flatnonzero(events & end_event) if ("dark" in label.name) and\ - any(name in event.name for name in ["landmark" or "wall"]): + any(name in event.name for name in ["landmark", "wall"]): # get dark onset and offset edges dark_on = ((events & event.dark_on) != 0).astype(np.int8) dark_off = ((events & event.dark_off) != 0).astype(np.int8) From f5e5bbe7458e9ab5af26702d599155bf30a5226c Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 14 Oct 2025 21:01:46 +0100 Subject: [PATCH 635/658] load merged sorting analyser if there is one --- pixels/behaviours/base.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index e76cb2b..5b3f530 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1632,9 +1632,15 @@ def select_units( streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): sa_dir = self.find_file(stream_files["sorting_analyser"]) - # load sorting analyser - temp_sa = si.load(sa_dir) - # NOTE: this load gives warning when using temp_wh.dat to build + # load sorting analyser, load merged if there is one + merged_dir = sa_dir.parent / "merged" + if merged_dir.exists(): + sa_dir = merged_dir / "merged_sa.zarr" + temp_sa = si.load(sa_dir) + else: + temp_sa = si.load(sa_dir) + + # NOTE: si.load gives 
warning when using temp_wh.dat to build # sorting analyser, and to load sa it will be loading binary # recording obj, and that checks version From 9b5c3870c4553767dbf7c5b5eab25c3f48dd95ff Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 14 Oct 2025 21:02:01 +0100 Subject: [PATCH 636/658] add file with merge units group --- pixels/ioutils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 564e078..e6b2b4b 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -176,6 +176,10 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["noisy_units"] = base_name.with_name( f"{session_name}_{probe_id}_noisy_units.yaml" ) + # mergeable units in curated units + pixels[stream_id]["mergeable_units"] = base_name.with_name( + f"{session_name}_{probe_id}_mergeable_units.yaml" + ) # old catgt data pixels[stream_id]["CatGT_ap_data"].append( From 225a195b66275f9d36a7e5d5e3d85e0809832edb Mon Sep 17 00:00:00 2001 From: amz_office Date: Tue, 14 Oct 2025 21:02:22 +0100 Subject: [PATCH 637/658] set default to thread for fedora cuz process does not work for now... --- pixels/configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pixels/configs.py b/pixels/configs.py index f0006c8..934019d 100644 --- a/pixels/configs.py +++ b/pixels/configs.py @@ -18,9 +18,9 @@ # set si job_kwargs job_kwargs = dict( - #pool_engine="thread", # instead of default "process" - pool_engine="process", - mp_context="fork", # linux, but does not work still on 2025 aug 12 + pool_engine="thread", # instead of default "process" + #pool_engine="process",# does not work on 2025 oct 14 + mp_context="fork", # linux #mp_context="spawn", # mac & win progress_bar=True, n_jobs=0.8, From 7e1f26e51968ecdcf2103f215097ab683ebf5248 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 15 Oct 2025 16:32:41 +0100 Subject: [PATCH 638/658] add merged sa and use it if exists --- pixels/behaviours/base.py | 27 ++++++++++++++++++--------- pixels/ioutils.py | 8 ++++++-- pixels/stream.py | 10 ++++++++-- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 5b3f530..a7178a5 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1507,9 +1507,17 @@ def _get_si_spike_times(self, units): streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - sa_dir = self.find_file(stream_files["sorting_analyser"]) + # find sorting analyser, use merged if there is one + merged_sa_dir = self.find_file( + stream_files["merged_sorting_analyser"] + ) + if merged_sa_dir: + sa_dir = merged_sa_dir + else: + sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser temp_sa = si.load_sorting_analyzer(sa_dir) + # select units sorting = temp_sa.sorting.select_units(units) sa = temp_sa.select_units(units) @@ -1631,15 +1639,16 @@ def select_units( streams = self.files["pixels"] for stream_num, (stream_id, stream_files) in enumerate(streams.items()): - sa_dir = self.find_file(stream_files["sorting_analyser"]) - # load sorting analyser, load merged if there is one - merged_dir = sa_dir.parent / "merged" - if merged_dir.exists(): - sa_dir = merged_dir / "merged_sa.zarr" - temp_sa = si.load(sa_dir) + # find sorting analyser, use merged if there is one + merged_sa_dir = self.find_file( + stream_files["merged_sorting_analyser"] + ) + if merged_sa_dir: + sa_dir = merged_sa_dir else: - temp_sa = si.load(sa_dir) - + sa_dir = 
self.find_file(stream_files["sorting_analyser"]) + # load sorting analyser + temp_sa = si.load_sorting_analyzer(sa_dir) # NOTE: si.load gives warning when using temp_wh.dat to build # sorting analyser, and to load sa it will be loading binary # recording obj, and that checks version diff --git a/pixels/ioutils.py b/pixels/ioutils.py index e6b2b4b..73512e1 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -125,8 +125,12 @@ def get_data_files(data_dir, session_name): pixels[stream_id]["detected_peaks"] = base_name.with_name( f"{base_name.stem}_detected_peaks.h5" ) - pixels[stream_id]["sorting_analyser"] = base_name.parent/\ - f"sorted_stream_{probe_id[-1]}/curated_sa.zarr" + sorted_stream_dir = base_name.parent / f"sorted_stream_{probe_id[-1]}" + pixels[stream_id]["sorting_analyser"] = sorted_stream_dir /\ + "curated_sa.zarr" + # if performed units merge, we have an updated sorting analyser + pixels[stream_id]["merged_sorting_analyser"] = sorted_stream_dir /\ + "merged_sa.zarr" # <<< spikeinterface cache <<< # depth info of probe diff --git a/pixels/stream.py b/pixels/stream.py index df7415e..1968d80 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -615,8 +615,14 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): def get_spike_times(self, units): units = units[self.stream_id] - # find sorting analyser path - sa_path = self.session.find_file(self.files["sorting_analyser"]) + # find sorting analyser, use merged if there is one + merged_sa_dir = self.find_file( + stream_files["merged_sorting_analyser"] + ) + if merged_sa_dir: + sa_dir = merged_sa_dir + else: + sa_dir = self.find_file(stream_files["sorting_analyser"]) # load sorting analyser temp_sa = si.load_sorting_analyzer(sa_path) # select units From fc717e07172fdf2bd998f0ad18b41906b7f54821 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 19:46:59 +0100 Subject: [PATCH 639/658] get chance of we do align units from trial start to end --- pixels/behaviours/base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index a7178a5..683fe04 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -1647,6 +1647,7 @@ def select_units( sa_dir = merged_sa_dir else: sa_dir = self.find_file(stream_files["sorting_analyser"]) + # load sorting analyser temp_sa = si.load_sorting_analyzer(sa_dir) # NOTE: si.load gives warning when using temp_wh.dat to build @@ -1828,6 +1829,16 @@ def align_trials( end_event=end_event, ) + if units.name == "all" and event.name == "trial_start"\ + and end_event.name == "trial_end": + stream.get_spike_chance( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + ) + return output if data == "motion_tracking" and not dlc_project: From 3a62cce56850965dff449fe052978790c8255a14 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 19:47:57 +0100 Subject: [PATCH 640/658] use path not dir --- pixels/stream.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index 1968d80..ac3e0b9 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -616,13 +616,13 @@ def get_spike_times(self, units): units = units[self.stream_id] # find sorting analyser, use merged if there is one - merged_sa_dir = self.find_file( - stream_files["merged_sorting_analyser"] + merged_sa_path = self.session.find_file( + self.files["merged_sorting_analyser"] ) - if merged_sa_dir: - sa_dir = merged_sa_dir + if merged_sa_path: + 
sa_path = merged_sa_path else: - sa_dir = self.find_file(stream_files["sorting_analyser"]) + sa_path = self.session.find_file(self.files["sorting_analyser"]) # load sorting analyser temp_sa = si.load_sorting_analyzer(sa_path) # select units From b7204b2782e0b076e49aae82c97755dd4e8192a6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 22:20:16 +0100 Subject: [PATCH 641/658] get max channel index --- pixels/pixels_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index b3939ca..79630b5 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -647,7 +647,8 @@ def _curate_sorting(sorting, recording, output): coords = sa.get_channel_locations() # get coordinates of max channel of each unit on probe, column 0 is # x-axis, column 1 is y-axis/depth, 0 at bottom-left channel. - max_chan_coords = coords[sa.channel_ids_to_indices(max_chan)] + max_chan_idx = sa.channel_ids_to_indices(max_chan) + max_chan_coords = coords[max_chan_idx] # set coordinates of max channel of each unit as a property of sorting sa.sorting.set_property( key="max_chan_coords", From 48285e40b2981525dd282ebd8490dcf402a4d214 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 22:20:54 +0100 Subject: [PATCH 642/658] rename to differentiate from template metrics rule --- pixels/pixels_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 79630b5..192f71a 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -661,9 +661,9 @@ def _curate_sorting(sorting, recording, output): # & amplitude_cutoff < 0.05 & sd_ratio < 1.5 & presence_ratio > 0.9\ # & snr > 1.1 & rp_contamination < 0.2 & firing_rate > 0.1" # use the ibl methods, but amplitude_cutoff rather than noise_cutoff - rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -40\ + qms_rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -40\ & presence_ratio > 0.9" - good_qms = qms.query(rule) + good_qms = qms.query(qms_rule) # TODO nov 26 2024 # wait till noise cutoff implemented and include that. # also see why sliding rp violation gives loads nan. From 238c14de80c95b0ff64938f371645e8ffaed63a5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 22:23:38 +0100 Subject: [PATCH 643/658] show removed bad unit ids --- pixels/pixels_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 192f71a..7bb6ddf 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -664,9 +664,17 @@ def _curate_sorting(sorting, recording, output): qms_rule = "snr > 1.1 & rp_contamination < 0.2 & amplitude_median <= -40\ & presence_ratio > 0.9" good_qms = qms.query(qms_rule) + logging.info( + "> quality metrics check removed " + f"{np.setdiff1d(sa.unit_ids, good_qms.index.values)}." + ) # TODO nov 26 2024 # wait till noise cutoff implemented and include that. # also see why sliding rp violation gives loads nan. + logging.info( + "> Template metrics check removed " + f"{np.setdiff1d(sa.unit_ids, good_tms.index.values)}." 
+ ) # get unit ids curated_unit_ids = list(good_qms.index) # select curated From 5d18df6cff6df31b51702e8b8293356de40eb5a2 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 22:26:08 +0100 Subject: [PATCH 644/658] use template metrics to filter units --- pixels/pixels_utils.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 7bb6ddf..d3fe5ed 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -671,12 +671,41 @@ def _curate_sorting(sorting, recording, output): # TODO nov 26 2024 # wait till noise cutoff implemented and include that. # also see why sliding rp violation gives loads nan. + + # calculate template metrics + tms = spost.compute_template_metrics( + sa, + include_multi_channel_metrics=True, + ) + # remove noise based on waveform + tms_rule = "num_positive_peaks <= 2 & num_negative_peaks == 1 &\ + exp_decay > 0.01 & exp_decay < 0.1" # bombcell + #peak_to_valley > 0.00018 &\ + good_tms = tms.query(tms_rule) logging.info( "> Template metrics check removed " f"{np.setdiff1d(sa.unit_ids, good_tms.index.values)}." ) + + # get good units that passed quality metrics & template metrics + good_units = np.intersect1d(good_qms.index.values, good_tms.index.values) + good_unit_mask = np.isin(sa.unit_ids, good_units) + + # get template of each unit on its max channel + templates = sa.load_extension("templates").get_data() + unit_idx = sa.sorting.ids_to_indices(good_units) + max_chan_templates = templates[unit_idx, :, max_chan_idx[good_unit_mask]] + + # filter non somatic units by waveform analysis + soma_mask = filter_non_somatics( + sa.unit_ids, + max_chan_templates, + sa.sampling_frequency, + ) + soma_units = good_units[soma_mask] + # get unit ids - curated_unit_ids = list(good_qms.index) + curated_unit_ids = np.intersect1d(good_units, soma_units) # select curated curated_sorting = sa.sorting.select_units(curated_unit_ids) curated_sa = sa.select_units(curated_unit_ids) From 788e08ed93410157017b57ee4fa76b8f1a03e4f1 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 22:27:05 +0100 Subject: [PATCH 645/658] implement template metrics filter non somatic units --- pixels/pixels_utils.py | 100 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d3fe5ed..3c0e005 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -2249,3 +2249,103 @@ def get_landmark_responsives(pos_fr, units, ons, offs): ).droplevel("index") return contrasts, responsives + + +def filter_non_somatics(unit_ids, templates, sampling_freq): + # NOTE: no need to worry about multi-positive-peak templates, cuz we already + # threw them out + from scipy.signal import find_peaks + + # True means yes somatic, False mean non somatic + mask = np.zeros(len(templates), dtype=bool) + + # bombcell non somatic criteria + max_repo_peak_to_trough_ratio = 0.8 + + # height ratio to trough + heigh_ratio_to_trough = 0.15 #0.2 + # minimum width of the peak for detection + min_width = 0 + + ## minimum width of depo peak for filtering + #peak_width_ms = 0.1 # 0.07 + #min_depo_peak_width = int(peak_width_ms / 1000 * sampling_freq) + + # NOTE: since our data is high-pass filtered, the very transient peak before + # could be caused by that, its width does not tell us whether it is a + # somatic unit or not. DO NOT RELY ON PRE TROUGH PEAK TO DECIDE SOMATIC! + # so there are two ways to do this: + # 1. 
set the detection minimum to min_depo_peak_width like in + # spikeinterface, so that we just ignore those transient peaks and consider + # it as high-pass-filter artefact, has no weight in somatic decision; + # 2. using a very low threshold during detection like what we do here, and + # check the height ratio of the peak, if it exceeds our maximum, still + # consider it as non somatic, even if it is very transient. + + # maximum depo peak to repo peak ratio + max_depo_peak_to_repo_peak_ratio = 1.5 + # maximum depo peak to trough ratio + max_depo_peak_to_trough_ratio = 0.5 + + for t, template in enumerate(templates): + # get absolute maximum + template_max = np.max(np.abs(template)) + # minimum prominence of the peak + prominence = heigh_ratio_to_trough * template_max + # get trough index + trough_idx = np.argmin(template) + trough_height = np.abs(template)[trough_idx] + + # get positive peaks + peak_idx, peak_properties = find_peaks( + x=template, + prominence=prominence, + width=min_width, + ) + + assert len(peak_idx) <= 2, "why are there more than 2 peaks?" + + if not len(peak_idx) == 0: + # maximum positive peak is not larger than 80% of trough + if not np.abs(template[peak_idx][-1])\ + / trough_height < max_repo_peak_to_trough_ratio: + print( + f"> {unit_ids[t]} has positive peak larger than 80% trough" + ) + continue + + if len(peak_idx) > 1: + # if both peaks before or after trough, consider non somatic + if (peak_idx > trough_idx).all()\ + or (peak_idx < trough_idx).all(): + print(f"> {unit_ids[t]} both peaks on one side") + continue + + # check compare bases + # if pre depolarisation peak + assert peak_properties["right_bases"][0] == trough_idx + # if post repolarisation peak + assert peak_properties["left_bases"][1] == trough_idx + + peak_heights = template[peak_idx] + + # check depo peak height is + if not peak_heights[0] / trough_height\ + < max_depo_peak_to_trough_ratio: + print(f"> {unit_ids[t]} depo peak is half the size of trough") + continue + + # check height ratio of peaks, make sure the depolarisation peak is + # NOT much bigger than repolarisation peak + if not peak_heights[0] / peak_heights[1]\ + < max_depo_peak_to_repo_peak_ratio: + print( + f"> {unit_ids[t]} depo peak is 1.5 times larger than " + "the repo" + ) + continue + + # yes somatic if no positive peaks or pass all checks + mask[t] = True + + return mask From f73f766b279c47085866c30d506bf66911b6d50a Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 22:47:13 +0100 Subject: [PATCH 646/658] get chance with aligning in base.py --- pixels/stream.py | 51 ------------------------------------------------ 1 file changed, 51 deletions(-) diff --git a/pixels/stream.py b/pixels/stream.py index ac3e0b9..cf77b9b 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -366,39 +366,6 @@ def _get_aligned_trials( ) output["positions"] = positions - if units.name == "all" and (label.name == "light" or "dark"): - self.get_spike_chance( - #stream_files=stream_files, - spiked=stacked_spiked, - sigma=sigma, - ) - #else: - # # access chance data if we only need part of the units - # self.get_spike_chance( - # sample_rate=self.SAMPLE_RATE, - # positions=all_pos, - # sigma=sigma, - # ) - # assert 0 - - # get trials horizontally stacked spiked - #spiked = ioutils.reindex_by_longest( - # dfs=stacked_spiked, - # level="trial", - # return_format="dataframe", - #) - #fr = ioutils.reindex_by_longest( - # dfs=trials_fr, - # level="trial", - # idx_names=["trial", "time"], - # col_names=["unit"], - # return_format="dataframe", - 
#) - - #output["spiked"] = spiked - #output["fr"] = fr - #output["positions"] = positions - return output @@ -558,24 +525,6 @@ def _get_aligned_events(self, label, event, units=None, sigma=None): stacked_spiked.index.names = ["trial", "time"] stacked_spiked.columns.names = ["unit"] - # TODO apr 21 2025: - # save spike chance only if all units are selected, else - # only index into the big chance array and save into zarr - #if units.name == "all" and (label == 725 or 1322): - # self.save_spike_chance( - # stream_files=stream_files, - # spiked=stacked_spiked, - # sigma=sigma, - # ) - #else: - # # access chance data if we only need part of the units - # self.get_spike_chance( - # sample_rate=self.SAMPLE_RATE, - # positions=all_pos, - # sigma=sigma, - # ) - # assert 0 - # get trials horizontally stacked spiked spiked = ioutils.reindex_by_longest( dfs=stacked_spiked, From 84b3cd9f0d3d65394fe6255df1be4a204c9c4bd8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Fri, 17 Oct 2025 22:49:45 +0100 Subject: [PATCH 647/658] add todo --- pixels/stream.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pixels/stream.py b/pixels/stream.py index cf77b9b..ee00b97 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -370,6 +370,17 @@ def _get_aligned_trials( def _get_aligned_events(self, label, event, units=None, sigma=None): + # TODO oct 17 2025: + # use _get_vr_spikes?? + # get spikes and firing rate + _, output = self._get_vr_spikes( + units, + label, + event, + sigma, + end_event, + ) + # get synched pixels stream with vr and action labels synched_vr, action_labels = self.get_synched_vr() From 01c4946b837885bac6f20e7d3a4492ad2476445c Mon Sep 17 00:00:00 2001 From: az_delta Date: Sun, 19 Oct 2025 13:42:16 +0100 Subject: [PATCH 648/658] allow of_dates from mice in nested list --- pixels/ioutils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pixels/ioutils.py b/pixels/ioutils.py index 73512e1..03f153b 100644 --- a/pixels/ioutils.py +++ b/pixels/ioutils.py @@ -474,7 +474,7 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt, of_date=None): sessions = {} raw_dir = data_dir / "raw" - for mouse in mouse_ids: + for m, mouse in enumerate(mouse_ids): mouse_sessions = sorted(list(raw_dir.glob(f"*{mouse}"))) if not mouse_sessions: @@ -490,6 +490,8 @@ def get_sessions(mouse_ids, data_dir, meta_dir, session_date_fmt, of_date=None): if of_date is not None: if isinstance(of_date, str): date_list = [of_date] + elif is_nested_list(of_date): + date_list = of_date[m] else: date_list = of_date @@ -847,6 +849,13 @@ def is_nested_dict(d): return any(isinstance(v, dict) for v in d.values()) +def is_nested_list(ls): + """ + Returns True if at least one value in list d is a list. 
+ """ + return any(isinstance(item, list) for item in ls) + + def save_index_to_frame(df, path): idx_df = df.index.to_frame(index=False) write_hdf5( From 32970cc748aed864214b6b5f532d9b478ffabfb9 Mon Sep 17 00:00:00 2001 From: az_delta Date: Wed, 22 Oct 2025 17:23:24 +0100 Subject: [PATCH 649/658] change name of arg to em_method (estimate_motion_method) cuz the the different bit for ap and lfp is at motion estimation step --- pixels/pixels_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 3c0e005..c3d846b 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -161,11 +161,11 @@ def CAR(rec, dtype=np.int16): return car -def correct_lfp_motion(rec, mc_method="dredge_lfp"): +def correct_lfp_motion(rec, em_method="dredge_lfp"): raise NotImplementedError("> Not implemented.") -def correct_ap_motion(rec, mc_method="dredge_ap"): +def correct_ap_motion(rec, em_method="dredge_ap"): """ Correct motion of recording. @@ -180,11 +180,13 @@ def correct_ap_motion(rec, mc_method="dredge_ap"): === None """ - logging.info(f"\n> Correcting motion with {mc_method}.") + logging.info(f"\n> Correcting motion with {em_method}.") + + mc_method = em_method.split("_")[0] # reduce spatial window size for four-shank estimate_motion_kwargs = { - "method": "dredge_ap", + "method": f"{em_method}", "win_step_um": 100, "win_margin_um": -150, "verbose": True, From 499ca4da969fed875fa89b684a8d980570952fe2 Mon Sep 17 00:00:00 2001 From: az_delta Date: Wed, 22 Oct 2025 17:49:50 +0100 Subject: [PATCH 650/658] input motion correction method, then define motion estimation method --- pixels/behaviours/base.py | 2 +- pixels/experiment.py | 2 +- pixels/pixels_utils.py | 21 ++++++++++++++++----- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 683fe04..35f4f01 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -525,7 +525,7 @@ def process_behaviour(self): logging.info("\n> Done!") - def correct_ap_motion(self, mc_method="dredge_ap"): + def correct_ap_motion(self, mc_method="dredge"): """ Correct motion of recording. diff --git a/pixels/experiment.py b/pixels/experiment.py index abb3663..df805c3 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -132,7 +132,7 @@ def extract_ap(self): .format(session.name, i + 1, len(self.sessions))) session.extract_ap() - def sort_spikes(self, mc_method="dredge_ap"): + def sort_spikes(self, mc_method="dredge"): """ Extract the spikes from raw spike data for all sessions. """ for i, session in enumerate(self.sessions): logging.info( diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index c3d846b..9b74546 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -161,18 +161,24 @@ def CAR(rec, dtype=np.int16): return car -def correct_lfp_motion(rec, em_method="dredge_lfp"): +def correct_lfp_motion(rec, mc_method="dredge"): + if mc_method == "dredge": + em_method = mc_method+"_lfp" + else: + em_method = spre.motion.motion_options_preset[mc_method][ + "estimate_motion_kwargs" + ]["method"] raise NotImplementedError("> Not implemented.") -def correct_ap_motion(rec, em_method="dredge_ap"): +def correct_ap_motion(rec, mc_method="dredge"): """ Correct motion of recording. params === mc_method: str, motion correction method. - Default: "dredge_ap". + Default: "dredge". (as of jan 2025, dredge performs better than ks motion correction.) "ks": let kilosort do motion correction. 
@@ -180,9 +186,14 @@ def correct_ap_motion(rec, em_method="dredge_ap"): === None """ - logging.info(f"\n> Correcting motion with {em_method}.") + logging.info(f"\n> Correcting motion with {mc_method}.") - mc_method = em_method.split("_")[0] + if mc_method == "dredge": + em_method = mc_method+"_ap" + else: + em_method = spre.motion.motion_options_preset[mc_method][ + "estimate_motion_kwargs" + ]["method"] # reduce spatial window size for four-shank estimate_motion_kwargs = { From be20291c5a0115c3bdb94022bb0f056676f2b060 Mon Sep 17 00:00:00 2001 From: az_delta Date: Wed, 22 Oct 2025 17:50:37 +0100 Subject: [PATCH 651/658] show surface depth --- pixels/pixels_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9b74546..2381dc4 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -138,7 +138,10 @@ def _preprocess_raw(rec, surface_depth, faulty_channels): # remove channels outside by using identified brain surface depths outside_chan_ids = chan_ids[chan_depths > surface_depth] rec_clean = rec_removed.remove_channels(outside_chan_ids) - print(f"\t\t> Removed {outside_chan_ids.size} outside channels.") + print( + f"\t\t> Removed {outside_chan_ids.size} outside channels + above {surface_depth}um." + ) return rec_clean From 9effe46e9000dcb5e459f9394227ca4953d0c8d6 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Oct 2025 17:58:09 +0100 Subject: [PATCH 652/658] add double quotes --- pixels/pixels_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 2381dc4..9aea226 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -139,8 +139,8 @@ def _preprocess_raw(rec, surface_depth, faulty_channels): outside_chan_ids = chan_ids[chan_depths > surface_depth] rec_clean = rec_removed.remove_channels(outside_chan_ids) print( - f"\t\t> Removed {outside_chan_ids.size} outside channels - above {surface_depth}um." + f"\t\t> Removed {outside_chan_ids.size} outside channels " + f"above {surface_depth}um." ) return rec_clean From ebbda0102cfdcd27a16c4ad06745eaad363362f5 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Oct 2025 17:58:44 +0100 Subject: [PATCH 653/658] get binned chance data --- pixels/behaviours/base.py | 42 +++++++++++++++++++++++++++++++++++++++ pixels/experiment.py | 13 ++++++++++++ pixels/stream.py | 29 +++++++++++++++++++++++++++ 3 files changed, 84 insertions(+) diff --git a/pixels/behaviours/base.py b/pixels/behaviours/base.py index 35f4f01..3583f77 100644 --- a/pixels/behaviours/base.py +++ b/pixels/behaviours/base.py @@ -2580,6 +2580,48 @@ def get_binned_trials( return binned + def get_binned_chance( + self, label, event, sigma, end_event, time_bin, pos_bin, + ): + """ + This function saves binned data in the format that Andrew wants: + trials * units * temporal bins (ms) + + time_bin: str | None + For VR behaviour, size of temporal bin for spike rate data. + + pos_bin: int | None + For VR behaviour, size of positional bin for position data. + """ + units = self.select_units(name="all") + + binned = {} + streams = self.files["pixels"] + for stream_num, (stream_id, stream_files) in enumerate(streams.items()): + stream = Stream( + stream_id=stream_id, + stream_num=stream_num, + files=stream_files, + session=self, + ) + + logging.info( + f"\n> Getting binned <{label.name}> chance from {stream_id}." 
+ ) + binned[stream_id] = stream.get_binned_chance( + units=units, + label=label, + event=event, + sigma=sigma, + end_event=end_event, + time_bin=time_bin, + pos_bin=pos_bin, + ) + + return binned + + + def get_chance_positional_psd(self, units, label, event, sigma, end_event): streams = self.files["pixels"] chance_psd = {} diff --git a/pixels/experiment.py b/pixels/experiment.py index df805c3..f7f910b 100644 --- a/pixels/experiment.py +++ b/pixels/experiment.py @@ -615,3 +615,16 @@ def get_spike_chance(self, *args, **kwargs): chance[name] = session.get_spike_chance(*args, **kwargs) return chance + + + def get_binned_chance(self, *args, **kwargs): + """ + Get binned chance firing rate and spike count for aligned vr trials. + Check behaviours.base.Behaviour.get_binned_chance for usage information. + """ + binned = {} + for i, session in enumerate(self.sessions): + name = session.name + binned[name] = session.get_binned_chance(*args, **kwargs) + + return binned diff --git a/pixels/stream.py b/pixels/stream.py index ee00b97..d039068 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1220,6 +1220,35 @@ def get_spike_chance(self, units, label, event, sigma, end_event, return None + @cacheable(cache_format="zarr") + def get_binned_chance( + self, units, label, event, sigma, end_event, time_bin, pos_bin, + ): + # get array name and path + name_parts = [self.session.name, self.probe_id, label.name, units.name, + "shuffled", time_bin, f"{pos_bin}cm.npz"] + file_name = "_".join(p for p in name_parts) + arr_path = self.processed / file_name + + # get chance data + chance_data = self.get_spike_chance( + units, + label, + event, + sigma, + end_event, + ) + # bin chance data + binned_chance = xut.bin_chance_spikes( + chance_data, + time_bin, + pos_bin, + arr_path, + ) + + return binned_chance + + def _get_chance_args(self, label, event, sigma, end_event): probe_id = self.stream_id[:-3] name = self.session.name From e3e738d1ed589d80d351ba30fd6d9c29f46421cc Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Oct 2025 18:07:30 +0100 Subject: [PATCH 654/658] save chance data in zarr not memmap --- pixels/pixels_utils.py | 325 +---------------------------------------- 1 file changed, 3 insertions(+), 322 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 9aea226..6721110 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -40,7 +40,6 @@ from pixels.decorators import _df_to_zarr_via_xarray, _df_from_zarr_via_xarray from common_utils import math_utils -from common_utils.file_utils import init_memmap, read_hdf5 def load_raw(paths, stream_id): """ @@ -819,72 +818,6 @@ def _permute_spikes_n_convolve_fr(array, sigma, sample_rate): return random_spiked, random_fr -def _chance_worker(i, sigma, sample_rate, spiked_shape, chance_data_shape, - spiked_memmap_path, fr_memmap_path, concat_spiked_path): - """ - Worker that computes one set of spiked and fr values. - - params - === - i: index of current repeat. - - sigma: int/float, time in millisecond of sigma of gaussian kernel for firing - rate convolution. - - sample_rate: float/int, sampling rate of signal. - - spiked_shape: tuple, shape of spike boolean to initiate memmap. - - chance_data_shape: tuple, shape of chance data. 
- - spiked_memmap_path: - - fr_memmap_path: - - return - === - """ - logging.info(f"\nProcessing repeat {i}...") - # open readonly memmap - spiked = init_memmap( - path=concat_spiked_path, - shape=spiked_shape, - dtype=np.int16, - overwrite=False, - readonly=True, - ) - - # init appendable memmap - chance_spiked = init_memmap( - path=spiked_memmap_path, - shape=chance_data_shape, - dtype=np.int16, - overwrite=False, - readonly=False, - ) - # init chance firing rate memmap - chance_fr = init_memmap( - path=fr_memmap_path, - shape=chance_data_shape, - dtype=np.float32, - overwrite=False, - readonly=False, - ) - - # get permuted data - c_spiked, c_fr = _permute_spikes_n_convolve_fr(spiked, sigma, sample_rate) - - chance_spiked[..., i] = c_spiked - chance_fr[..., i] = c_fr - # write to disk - chance_spiked.flush() - chance_fr.flush() - - logging.info(f"\nRepeat {i} finished.") - - return None - - def _worker_write_repeat( i, zarr_path, sigma, sample_rate, shm_name, shape, dtype_str, ): @@ -1088,261 +1021,17 @@ def save_spike_chance_zarr( return None -def save_spike_chance(spiked_memmap_path, fr_memmap_path, spiked_df_path, - fr_df_path, sigma, sample_rate, repeats=100, spiked=None, - spiked_shape=None, concat_spiked_path=None): - assert 0 - if fr_df_path.exists(): - # save spike chance data if does not exists - _save_spike_chance( - spiked_memmap_path, fr_memmap_path, spiked_df_path, fr_df_path, sigma, - sample_rate, repeats, spiked, spiked_shape, - concat_spiked_path) - else: - logging.info(f"\n> Spike chance already saved at {fr_df_path}, continue.") - - return None + ) -def _save_spike_chance(spiked_memmap_path, fr_memmap_path, sigma, sample_rate, - repeats, spiked, spiked_shape, concat_spiked_path): - """ - Implementation of saving chance level spike data. - """ + positions=trial_pos, - # save spiked to memmap if not yet - # TODO apr 9 2025: if i have temp_spiked, how to get its shape? do i need - # another input arg??? this is to run it again without get the concat spiked - # again... - if spiked is None: - assert concat_spiked_path.exists() - assert spiked_shape is not None - else: - concat_spiked_path = spiked_memmap_path.parent/"temp_spiked.bin" - spiked_memmap = init_memmap( - path=concat_spiked_path, - shape=spiked.shape, - dtype=np.int16, - overwrite=True, - readonly=False, ) - spiked_memmap[:] = spiked.values - spiked_memmap.flush() - del spiked_memmap - - # get spiked data shape - spiked_shape = spiked.shape - - # get export data shape - d_shape = spiked_shape + (repeats,) - # TODO apr 9 2025 save dshape to json - #with open(shape_json, "w") as f: - #json.dump(shape, f, indent=4) - - if not fr_memmap_path.exists(): - # init chance spiked memmap - chance_spiked = init_memmap( - path=spiked_memmap_path, - shape=d_shape, - dtype=np.int16, - overwrite=True, - readonly=False, - ) - # init chance firing rate memmap - chance_fr = init_memmap( - path=fr_memmap_path, - shape=d_shape, - dtype=np.float32, - overwrite=True, - readonly=False, ) - # write to disk - chance_spiked.flush() - chance_fr.flush() - del chance_spiked, chance_fr - - # Set up the process pool to run the worker in parallel. - with ProcessPoolExecutor() as executor: - # Submit jobs for each repeat. 
- futures = [] - for i in range(repeats): - future = executor.submit( - _chance_worker, - i=i, - sigma=sigma, - sample_rate=sample_rate, - spiked_shape=spiked_shape, - chance_data_shape=d_shape, - spiked_memmap_path=spiked_memmap_path, - fr_memmap_path=fr_memmap_path, - concat_spiked_path=concat_spiked_path, - ) - futures.append(future) - - # As each future completes, assign the results into the memmap. - for future in concurrent.futures.as_completed(futures): - future.result() - else: - logging.info( - "\n> Memmaps already created, only need to convert into " - "dataframes and save." - ) - - # convert it to dataframe and save it - #save_chance( - # orig_idx=spiked.index, - # orig_col=spiked.columns, - # spiked_memmap_path=spiked_memmap_path, - # fr_memmap_path=fr_memmap_path, - # spiked_df_path=spiked_df_path, - # fr_df_path=fr_df_path, - # d_shape=d_shape, - #) - #logging.info(f"\n> Chance data saved to {fr_df_path}.") - - return None - - -def _convert_to_df(orig_idx, orig_col, memmap_path, df_path, d_shape, d_type, - name): - """ - Convert - - orig_idx, - orig_col, - memmap_path - df_path, - d_shape - d_type - name - """ - # NOTE: shape of memmap is `concatenated trials frames * units * repeats`, - # saved df has outer most level being `repeat`, then `unit`, and all trials - # are stacked vertically. - # to later use it for analysis, go into each repeat, and do - # `repeat_df.unstack(level='trial', sort=False)` to get the same structure as - # data. - - # init readonly chance memmap - chance_memmap = init_memmap( - path=memmap_path, - shape=d_shape, - dtype=d_type, - overwrite=False, - readonly=True, - ) - # copy to cpu - c_spiked = chance_memmap.copy() - # reshape to 2D - c_spiked_reshaped = c_spiked.reshape(d_shape[0], d_shape[1] * d_shape[2]) - del c_spiked - # create hierarchical index - col_idx = pd.MultiIndex.from_product( - [np.arange(d_shape[2]), orig_col], - names=["repeat", "unit"], + # save np array, for andrew ) - - # create df - df = pd.DataFrame(c_spiked_reshaped, columns=col_idx) - # use the original index - df.index = orig_idx - - # write h5 to disk - write_hdf5( - path=df_path, - df=df, - key=name, - mode="w", - ) - del df - - return None - - -def save_chance(orig_idx, orig_col, spiked_memmap_path, fr_memmap_path, - spiked_df_path, fr_df_path, d_shape): - """ - Saving chance data to dataframe. 
- - params - === - orig_idx: pandas - """ - logging.info(f"\n> Saving chance data...") - - # get chance spiked df - _convert_to_df( - orig_idx=orig_idx, - orig_col=orig_col, - memmap_path=spiked_memmap_path, - df_path=spiked_df_path, - d_shape=d_shape, - d_type=np.int16, - name="spiked", - ) - # get chance fr df - _convert_to_df( - orig_idx=orig_idx, - orig_col=orig_col, - memmap_path=fr_memmap_path, - df_path=fr_df_path, - d_shape=d_shape, - d_type=np.float32, - name="fr", - ) - - return None - - -def get_spike_chance(sample_rate, positions, spiked_memmap_path, fr_memmap_path, - memmap_shape_path, idx_path, col_path): - if not fr_memmap_path.exists(): - raise PixelsError("\nHave you saved spike chance data yet?") - else: - fr_chance, idx, cols = _get_spike_chance( - sample_rate, positions, spiked_memmap_path, fr_memmap_path, - memmap_shape_path, idx_path, col_path) - return fr_chance, idx, cols - - return None - - -def _get_spike_chance(sample_rate, positions, spiked_memmap_path, - fr_memmap_path, memmap_shape_path, idx_path, col_path): - - # TODO apr 9 2025: - # i do not need to save shape to file, all i need is unit count, repeat, - # so i load memmap without defining shape, then directly np.reshape(memmap, - # (-1, count, repeat))! - - with open(memmap_shape_path, "r") as f: - shape_data = json.load(f) - shape_list = shape_data.get("dshape", []) - d_shape = tuple(shape_list) - - spiked_chance = init_memmap( - path=spiked_memmap_path, - shape=d_shape, - dtype=np.int16, - overwrite=False, - readonly=True, - ) - - idx_df = read_hdf5(idx_path, key="multiindex") - idx = pd.MultiIndex.from_frame(idx_df) - trials = idx_df["trial"].unique() - col_df = read_hdf5(col_path, key="cols") - cols = pd.Index(col_df["unit"]) - - return fr_chance, idx, cols - assert 0 - - # TODO jun 16 2025: - # have a separate func for binning chance data - binned_shuffle = {} - temp = {} # TODO apr 3 2025: implement multiprocessing here! 
# get each repeat and create df for r in range(d_shape[-1]): @@ -1364,10 +1053,6 @@ def _get_spike_chance(sample_rate, positions, spiked_memmap_path, binned_shuffle[r] = reindex_by_longest( dfs=temp[r], return_format="array", - ) - # concat - binned_shuffle_counts = np.stack( - list(binned_shuffle.values()), axis=-1, ) shuffled_counts = { @@ -1388,10 +1073,6 @@ def _get_spike_chance(sample_rate, positions, spiked_memmap_path, # readonly=True, # ) - # TODO apr 2 2025: - # for fr chance, use memmap, go to each repeat, unstack, bin, then save it - # to .npz for andrew - pass def bin_vr_trial(data, positions, sample_rate, time_bin, pos_bin, From 1dc379c8241c861d378c7bb32a3fe1f6792bcea0 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Oct 2025 18:08:05 +0100 Subject: [PATCH 655/658] bin chance given chance is zarr not memmap --- pixels/pixels_utils.py | 130 +++++++++++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 36 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 6721110..d24f22e 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1021,58 +1021,116 @@ def save_spike_chance_zarr( return None +def bin_chance_spikes(chance_data, time_bin, pos_bin, arr_path): + # extract data from chance + chance_spiked = chance_data["chance_spiked"] + chance_fr = chance_data["chance_fr"] + REPEATS = chance_spiked.shape[-1] + + # get index and columns to reconstruct df + spiked = chance_data["spiked"].dropna(axis=0, how="all") + idx = spiked.index + cols = spiked.columns + trial_ids = spiked.index.get_level_values("trial").unique() + # get positions + positions = chance_data["positions"].dropna(axis=1, how="all") + + count_arrs = {} + fr_arrs = {} + count_dfs = {} + fr_dfs = {} + temp_spiked = {} + temp_fr = {} + for repeat in range(REPEATS): + r_fr = pd.DataFrame( + chance_fr[:, :, repeat], + index=idx, + columns=cols, + ) + r_spiked = pd.DataFrame( + chance_spiked[:, :, repeat], + index=idx, + columns=cols, ) + temp_spiked[repeat] = {} + temp_fr[repeat] = {} + for trial in trial_ids: + trial_pos = positions.xs(trial, level="trial", axis=1).dropna() + counts = r_spiked.xs(trial, level="trial", axis=0) + fr = r_fr.xs(trial, level="trial", axis=0) + # bin fr + temp_fr[repeat][trial] = bin_vr_trial( + data=fr, + positions=trial_pos, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="mean", # fr + ) + # bin spiked + temp_spiked[repeat][trial] = bin_vr_trial( + data=counts, positions=trial_pos, + sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + time_bin=time_bin, + pos_bin=pos_bin, + bin_method="sum", # spike count + ) + fr_dfs[repeat] = pd.concat( + temp_fr[repeat], + axis=0, ) + count_dfs[repeat] = pd.concat( + temp_spiked[repeat], + axis=0, ) + # np array for andrew + fr_arrs[repeat] = reindex_by_longest( + dfs=temp_fr[repeat], + return_format="array", + ) + count_arrs[repeat] = reindex_by_longest( + dfs=temp_spiked[repeat], + return_format="array", + ) # save np array, for andrew + arr_count_output = np.stack( + list(count_arrs.values()), + axis=-1, + dtype=np.float32, ) - # TODO apr 3 2025: implement multiprocessing here! 
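# Illustration of the per-repeat / per-trial slicing used in the loop above: one
# repeat of the (frames x units x repeats) chance array is wrapped in a DataFrame
# that carries the original (trial, time) row index, and a single trial is pulled
# out with .xs() before binning. Shapes and labels are invented for the example;
# only the pandas calls mirror the code above.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
chance = rng.integers(0, 2, size=(6, 3, 2))                        # frames, units, repeats
idx = pd.MultiIndex.from_product([[0, 1], range(3)], names=["trial", "time"])
cols = pd.Index([10, 11, 12], name="unit")

r_spiked = pd.DataFrame(chance[:, :, 0], index=idx, columns=cols)  # repeat 0
trial0 = r_spiked.xs(0, level="trial", axis=0)                     # time x unit for trial 0
print(trial0.shape)                                                # (3, 3)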
- # get each repeat and create df - for r in range(d_shape[-1]): - shuffled = spiked_chance[:, :, r] - # create df - df = pd.DataFrame(shuffled, index=idx, columns=cols) - temp[r] = {} - for t in trials: - counts = df.xs(t, level="trial", axis=0) - trial_pos = positions.loc[:, t].dropna() - temp[r][t] = bin_vr_trial( - counts, - trial_pos, - sample_rate, - time_bin, - pos_bin, - bin_method="sum", - ) - binned_shuffle[r] = reindex_by_longest( - dfs=temp[r], - return_format="array", + arr_fr_output = np.stack( + list(fr_arrs.values()), axis=-1, + dtype=np.float32, ) - shuffled_counts = { - "count": binned_shuffle_counts[:, :-2, ...], - "pos": binned_shuffle_counts[:, -2:, ...], + arrs = { + "count": arr_fr_output[:, :-2, ...], + "fr": arr_count_output[:, :-2, ...], + "pos": arr_fr_output[:, -2:, ...], } - #count_path='/home/amz/running_data/npx/interim/20240812_az_VDCN09/20240812_az_VDCN09_imec0_light_all_spike_counts_shuffled_200ms_10cm.npz' - count_path='/home/amz/running_data/npx/interim/20240812_az_VDCN09/20240812_az_VDCN09_imec0_dark_all_spike_counts_shuffled_200ms_10cm.npz' + np.savez_compressed(arr_path, **arrs) - np.savez_compressed(count_path, **shuffled_counts) - assert 0 - - # fr_chance = _get_spike_chance( - # path=fr_memmap_path, - # shape=d_shape, - # dtype=np.float32, - # overwrite=False, - # readonly=True, - # ) + # output df + df_fr = pd.concat( + fr_dfs, + axis=0, + names=["repeat", "trial", "bin_time"], + ).iloc[:, :-2] + temp = pd.concat( + count_dfs, + axis=0, + names=["repeat", "trial", "bin_time"], + ) + df_spiked = temp.iloc[:, :-2] + df_pos = temp.iloc[:, -2:] + return {"spiked": df_spiked, "fr": df_fr, "positions": df_pos} def bin_vr_trial(data, positions, sample_rate, time_bin, pos_bin, From 833d2deeb66e532a717998f95afac0c1afe66b0b Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Oct 2025 18:08:55 +0100 Subject: [PATCH 656/658] do not use welch by default it smooth too much --- pixels/pixels_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index d24f22e..231534d 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1426,7 +1426,7 @@ def _get_vr_positional_neural_data(positions, data_type, data): def get_spatial_psd(pos_fr): def _compute_psd(col): x = col.dropna().values.squeeze() - f, psd = math_utils.estimate_power_spectrum(x, use_welch=True) + f, psd = math_utils.estimate_power_spectrum(x, use_welch=False) # remove 0 to avoid infinity f = f[1:] psd = psd[1:] From 3e4143bb3164aad71c167ef48a18c3aa5eedce99 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Oct 2025 18:15:24 +0100 Subject: [PATCH 657/658] use implementation of df zarr in decorators --- pixels/pixels_utils.py | 135 +---------------------------------------- 1 file changed, 1 insertion(+), 134 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 231534d..828956f 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -37,7 +37,7 @@ from pixels.error import PixelsError from pixels.configs import * from pixels.constants import * -from pixels.decorators import _df_to_zarr_via_xarray, _df_from_zarr_via_xarray +from pixels.decorators import _df_to_zarr_via_xarray from common_utils import math_utils @@ -911,12 +911,6 @@ def save_spike_chance_zarr( compressor=compressor, mode="w", ) - #_write_df_as_zarr( - # root, - # spiked, - # group_name="spiked", - # compressor=compressor, - #) else: root.create_dataset( "spiked", @@ -964,12 +958,6 @@ def save_spike_chance_zarr( 
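# Background for the use_welch default change above: Welch's method averages the
# spectrum over overlapping segments, which reduces variance but also smooths and
# flattens narrow peaks, whereas the raw periodogram keeps full frequency
# resolution. This sketch uses scipy directly; that math_utils.estimate_power_spectrum
# wraps something equivalent is an assumption, not taken from this repo.
import numpy as np
from scipy import signal

fs = 100.0                                           # samples per unit distance (assumed)
t = np.arange(0, 10, 1 / fs)
x = np.sin(2 * np.pi * 8 * t)                        # narrow peak at 8 cycles/unit
x += np.random.default_rng(1).normal(scale=0.5, size=t.size)

f_p, psd_p = signal.periodogram(x, fs=fs)            # full resolution, noisier floor
f_w, psd_w = signal.welch(x, fs=fs, nperseg=256)     # segment-averaged, smoother
print(psd_p.max(), psd_w.max())                      # the 8-cycle peak is taller and sharper in psd_p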
compressor=compressor, mode="w", ) - #_write_df_as_zarr( - # root, - # positions, - # group_name="positions", - # compressor=compressor, - #) else: root.create_dataset( "positions", @@ -1551,127 +1539,6 @@ def notch_freq(rec, freq, bw=4.0): return notched -def _write_df_as_zarr( - root, # zarr.hierarchy.Group - df: pd.DataFrame, - group_name: str, - *, - compressor=None, -): - row_prefix = "row" - col_prefix = "col" - - # Remove any existing node (array or group) with this name - if group_name in root: - del root[group_name] - - # Ensure all index/column level names are defined - if isinstance(df.index, pd.MultiIndex): - row_names = _default_names(list(df.index.names), row_prefix) - else: - row_names = [df.index.name or f"{row_prefix}0"] - - if isinstance(df.columns, pd.MultiIndex): - col_names = _default_names(list(df.columns.names), col_prefix) - else: - col_names = [df.columns.name or f"{col_prefix}0"] - - # Stack ALL column levels to move them into the row index; result index levels = row_names + col_names - series = df.stack(col_names, future_stack=True) # Series with MultiIndex index - - # Build DataArray (dims are level names of the Series index, in order) - da = xr.DataArray.from_series(series).rename("values") - ds = da.to_dataset() - - # Mark attrs for round-trip (which dims belong to rows vs columns) - ds.attrs["__via"] = "pd_df_any_mi" - ds.attrs["__row_dims__"] = row_names - ds.attrs["__col_dims__"] = col_names - - # check size to determine chunking - chunking = {} - for name, size in ds.sizes.items(): - if size > BIG_CHUNKS: - chunking[name] = BIG_CHUNKS - else: - chunking[name] = SMALL_CHUNKS - - ds = ds.chunk(chunking) - - # compressor & object codec - encoding = { - "values": { - "compressor": compressor, - "chunks": tuple(chunking.values()), - } - } - if ds["values"].dtype == object and VLenUTF8 is not None: - encoding["values"]["object_codec"] = VLenUTF8() - - # Ensure coords are writable (handle object/string coords) - # If VLenUTF8 is available, set encoding for object coords; otherwise cast - # to str - for cname, coord in ds.coords.items(): - if coord.dtype == object: - if VLenUTF8 is not None: - encoding[cname] = { - "object_codec": VLenUTF8(), - "compressor": compressor, - } - else: - ds = ds.assign_coords({cname: coord.astype(str)}) - - # Write into a subgroup under the same store - ds.to_zarr( - store=root.store, - group=group_name, - mode="w", - encoding=encoding, - ) - - logging.info(f"\n> DataFrame {group_name} written to zarr.") - - return None - - -def _read_df_from_zarr(root, group_name: str) -> pd.DataFrame: - ds = xr.open_zarr( - store=root.store, - group=group_name, - consolidated=False, - chunks="auto", - ) - da = ds["values"] - row_dim = list(ds.attrs.get("__row_dims__") or []) - col_dim = list(ds.attrs.get("__col_dims__") or []) - - # Series with MultiIndex index (row_dim, *col_dim) - series = da.to_series() - - # If there are column dims, unstack them back to columns - if col_dim: - df = series.unstack(col_dim) - else: - # No column dims -> a single column DataFrame - df = series.to_frame(name="values") - - col_name = [df.columns.name] - if not (row_dim == df.index.names): - df.index.set_names(row_dim, inplace=True) - if not (col_dim == col_name): - if isinstance(df.columns, pd.MultiIndex): - df.columns.set_names(col_dim, inplace=True) - else: - df.columns.name = col_dim[0] - - return df - - -def _default_names(names: list[str | None], prefix: str) -> list[str]: - # Replace None level names with defaults: f"{prefix}{i}" - return [n if n is not None 
else f"{prefix}{i}" for i, n in enumerate(names)] - - # >>> landmark responsive helpers >>> def to_df(mean, std, zone): out = pd.DataFrame({"mean": mean, "std": std}).reset_index() From 900fe4afe7837431c35c6c81569c2781e3bc95c8 Mon Sep 17 00:00:00 2001 From: amz_office Date: Wed, 22 Oct 2025 18:15:47 +0100 Subject: [PATCH 658/658] add sampling rate --- pixels/pixels_utils.py | 6 +++--- pixels/stream.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pixels/pixels_utils.py b/pixels/pixels_utils.py index 828956f..06da97b 100644 --- a/pixels/pixels_utils.py +++ b/pixels/pixels_utils.py @@ -1009,7 +1009,7 @@ def save_spike_chance_zarr( return None -def bin_chance_spikes(chance_data, time_bin, pos_bin, arr_path): +def bin_spike_chance(chance_data, sample_rate, time_bin, pos_bin, arr_path): # extract data from chance chance_spiked = chance_data["chance_spiked"] chance_fr = chance_data["chance_fr"] @@ -1052,7 +1052,7 @@ def bin_chance_spikes(chance_data, time_bin, pos_bin, arr_path): temp_fr[repeat][trial] = bin_vr_trial( data=fr, positions=trial_pos, - sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + sample_rate=sample_rate, time_bin=time_bin, pos_bin=pos_bin, bin_method="mean", # fr @@ -1061,7 +1061,7 @@ def bin_chance_spikes(chance_data, time_bin, pos_bin, arr_path): temp_spiked[repeat][trial] = bin_vr_trial( data=counts, positions=trial_pos, - sample_rate=self.BEHAVIOUR_SAMPLE_RATE, + sample_rate=sample_rate, time_bin=time_bin, pos_bin=pos_bin, bin_method="sum", # spike count diff --git a/pixels/stream.py b/pixels/stream.py index d039068..473d664 100644 --- a/pixels/stream.py +++ b/pixels/stream.py @@ -1239,8 +1239,9 @@ def get_binned_chance( end_event, ) # bin chance data - binned_chance = xut.bin_chance_spikes( + binned_chance = xut.bin_spike_chance( chance_data, + self.BEHAVIOUR_SAMPLE_RATE, time_bin, pos_bin, arr_path,