Skip to content

Commit 11438d6

Browse files
committed
code refactor
1 parent 1a43d1c commit 11438d6

File tree

5 files changed

+259
-206
lines changed

5 files changed

+259
-206
lines changed

spike_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import mne
2+
import numpy as np
3+
from scipy.signal import hilbert, find_peaks
4+
from scipy.signal import detrend
5+
6+
# Load the EDF file using MNE
7+
file_path = 'MV_2.edf' # Replace with your actual file path
8+
raw_data = mne.io.read_raw_edf(file_path, preload=True)
9+
10+
# Get basic information about the file
11+
info = raw_data.info
12+
signal_labels = raw_data.ch_names
13+
sampling_rate = raw_data.info['sfreq']
14+
15+
# Apply notch filter to remove 60 Hz noise
16+
raw_notched = raw_data.copy().notch_filter(freqs=60, method='fir', fir_design='firwin')
17+
18+
# Bandpass filter the data between 25 and 80 Hz using MNE's built-in FIR filter
19+
low_cutoff = 25 # Lower bound of the bandpass filter in Hz
20+
high_cutoff = 80 # Upper bound of the bandpass filter in Hz
21+
raw_filtered = raw_notched.copy().filter(l_freq=low_cutoff, h_freq=high_cutoff, method='fir', fir_design='firwin', phase='zero-double')
22+
23+
24+
# Extract the filtered signal and unfiltered signal for the first channel
25+
filtered_signal_data, filtered_times = raw_filtered[0, :]
26+
unfiltered_signal_data, _ = raw_notched[0, :]
27+
28+
# Apply Hilbert transform to the filtered signal
29+
analytic_signal = hilbert(filtered_signal_data[0])
30+
31+
# Calculate the amplitude envelope (magnitude of the analytic signal)
32+
amplitude_envelope = np.abs(analytic_signal)
33+
34+
# Calculate the mean amplitude of the envelope
35+
mean_amplitude = np.mean(amplitude_envelope)
36+
37+
# Identify candidate spikes where the amplitude envelope exceeds 3 times the mean amplitude
38+
threshold = 3 * mean_amplitude
39+
candidate_spikes = np.where(amplitude_envelope > threshold)[0]
40+
41+
# Time window around each candidate spike (±0.25 s)
42+
window_size = int(0.25 * sampling_rate)
43+
44+
# List to store the valid spikes
45+
valid_spikes = []
46+
47+
for spike_idx in candidate_spikes:
48+
# Get the time window around the spike for unfiltered data
49+
start_idx = max(0, spike_idx - window_size)
50+
end_idx = min(len(unfiltered_signal_data[0]), spike_idx + window_size)
51+
52+
signal_window = unfiltered_signal_data[0][start_idx:end_idx]
53+
54+
# Detrend the window
55+
detrended_signal = detrend(signal_window)
56+
57+
# Identify peaks and troughs
58+
peaks, _ = find_peaks(detrended_signal)
59+
troughs, _ = find_peaks(-detrended_signal)
60+
61+
# Calculate the Fano factor
62+
if len(peaks) > 1 and len(troughs) > 1:
63+
# Calculate inter-peak and inter-trough intervals
64+
inter_peak_intervals = np.diff(peaks) / sampling_rate # Convert to seconds
65+
inter_trough_intervals = np.diff(troughs) / sampling_rate
66+
67+
peak_fano_factor = np.var(inter_peak_intervals) / np.mean(inter_peak_intervals)
68+
trough_fano_factor = np.var(inter_trough_intervals) / np.mean(inter_trough_intervals)
69+
70+
fano_factor = (peak_fano_factor + trough_fano_factor) / 2 # Average
71+
72+
# Calculate the maximum amplitude in the window (rectified)
73+
max_amplitude = np.max(np.abs(signal_window))
74+
75+
# Check the conditions for valid spikes
76+
if max_amplitude > 3 * mean_amplitude and fano_factor >= 2.5:
77+
valid_spikes.append(spike_idx)
78+
79+
# Merge spikes detected within 20 ms of each other
80+
merged_spikes = []
81+
previous_spike = -np.inf
82+
for spike in valid_spikes:
83+
if spike - previous_spike > 0.02 * sampling_rate:
84+
merged_spikes.append(spike)
85+
previous_spike = spike
86+
87+
# Display final spike indices and times
88+
spike_times = filtered_times[merged_spikes]
89+
print(f"Detected spikes at times (s): {spike_times}")

src/controllers/main_window_controller.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,23 @@ def create_hfo_window(self):
6767
self.model.init_param('HFO')
6868

6969
def create_spindle_window(self):
70-
print('not implemented yet')
70+
# create detection parameters stacked widget
71+
self.view.create_stacked_widget_detection_param('Spindle')
72+
73+
# create biomarker typ frame widget
74+
self.view.create_frame_biomarker('Spindle')
75+
76+
# manage flag
77+
self.view.window.is_data_filtered = False
78+
79+
# create center waveform and mini plot
80+
self.model.create_center_waveform_and_mini_plot()
81+
82+
# connect signal & slot
83+
self.model.connect_signal_and_slot('Spindle')
84+
85+
# init params
86+
self.model.init_param('Spindle')
7187

7288
def create_hypsarrhythmia_window(self):
7389
print('not implemented yet')

src/models/main_window_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def init_param(self, biomarker_type='HFO'):
109109
self.set_ste_input_len(8)
110110
self.set_hil_input_len(8)
111111
elif biomarker_type == 'Spindle':
112-
a = 2
112+
print('init_param not implemented')
113113

114114
def init_default_filter_input_params(self):
115115
default_params = ParamFilter()
@@ -234,6 +234,10 @@ def connect_signal_and_slot(self, biomarker_type='HFO'):
234234
if not self.gpu:
235235
# disable gpu buttons
236236
self.window.default_gpu_button.setEnabled(False)
237+
elif biomarker_type == 'Spindle':
238+
self.window.overview_filter_button.clicked.connect(self.filter_data)
239+
# set filter button to be disabled by default
240+
self.window.overview_filter_button.setEnabled(False)
237241

238242
def set_classifier_param_display(self):
239243
classifier_param = self.backend.get_classifier_param()

src/ui/main_window.py

Lines changed: 0 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -32,204 +32,3 @@ def __init__(self):
3232
# initialize biomarker specific UI
3333
biomarker = self.main_window_controller.get_biomarker_type()
3434
self.main_window_controller.init_biomarker_window(biomarker)
35-
36-
def create_spindle_window(self):
37-
clear_stacked_widget(self.stacked_widget_detection_param)
38-
page_yasa = self.create_detection_parameter_page_yasa('Detection Parameters (YASA)')
39-
self.stacked_widget_detection_param.addWidget(page_yasa)
40-
41-
self.detector_subtabs.clear()
42-
tab_yasa = self.create_detection_parameter_tab_yasa()
43-
self.detector_subtabs.addTab(tab_yasa, 'YASA')
44-
45-
# create biomarker type
46-
clear_layout(self.frame_biomarker_layout)
47-
self.create_frame_biomarker_spindle()
48-
49-
self.overview_filter_button.clicked.connect(self.filter_data)
50-
# set filter button to be disabled by default
51-
self.overview_filter_button.setEnabled(False)
52-
# self.show_original_button.clicked.connect(self.toggle_filtered)
53-
54-
self.is_data_filtered = False
55-
56-
self.waveform_plot = CenterWaveformAndMiniPlotController(self.waveform_plot_widget, self.waveform_mini_widget,
57-
self.hfo_app)
58-
59-
def create_hypsarrhythmia_window(self):
60-
print('not implemented yet')
61-
62-
def create_frame_biomarker_spindle(self):
63-
# self.frame_biomarker_layout = QHBoxLayout(self.frame_biomarker_type)
64-
self.frame_biomarker_layout.addStretch(1)
65-
66-
# Add three QLabel widgets to the QFrame
67-
self.label_type1 = QLabel("Artifact")
68-
self.label_type1.setFixedWidth(150)
69-
self.label_type2 = QLabel("spk-Spindle")
70-
self.label_type2.setFixedWidth(150)
71-
self.label_type3 = QLabel("Spindle")
72-
self.label_type3.setFixedWidth(150)
73-
74-
self.line_type1 = QLineEdit()
75-
self.line_type1.setReadOnly(True)
76-
self.line_type1.setFrame(True)
77-
self.line_type1.setFixedWidth(50)
78-
self.line_type1.setStyleSheet("background-color: orange;")
79-
self.line_type2 = QLineEdit()
80-
self.line_type2.setReadOnly(True)
81-
self.line_type2.setFrame(True)
82-
self.line_type2.setFixedWidth(50)
83-
self.line_type2.setStyleSheet("background-color: purple;")
84-
self.line_type3 = QLineEdit()
85-
self.line_type3.setReadOnly(True)
86-
self.line_type3.setFrame(True)
87-
self.line_type3.setFixedWidth(50)
88-
self.line_type3.setStyleSheet("background-color: green;")
89-
90-
# Add labels to the layout
91-
self.frame_biomarker_layout.addWidget(self.line_type1)
92-
self.frame_biomarker_layout.addWidget(self.label_type1)
93-
self.frame_biomarker_layout.addWidget(self.line_type2)
94-
self.frame_biomarker_layout.addWidget(self.label_type2)
95-
self.frame_biomarker_layout.addWidget(self.line_type3)
96-
self.frame_biomarker_layout.addWidget(self.label_type3)
97-
self.frame_biomarker_layout.addStretch(1)
98-
99-
def create_detection_parameter_page_yasa(self, groupbox_title):
100-
page = QWidget()
101-
layout = QGridLayout()
102-
103-
detection_groupbox_yasa = QGroupBox(groupbox_title)
104-
yasa_parameter_layout = QGridLayout(detection_groupbox_yasa)
105-
106-
clear_layout(yasa_parameter_layout)
107-
# self.detection_groupbox_hil.setTitle("Detection Parameters (HIL)")
108-
109-
# Create widgets
110-
text_font = QFont('Arial', 11)
111-
label1 = QLabel('Freq Spindle (Hz)')
112-
label2 = QLabel('Freq Broad (Hz)')
113-
label3 = QLabel('Duration (s)')
114-
label4 = QLabel('Min Distance (ms)')
115-
label5 = QLabel('rel_pow')
116-
label6 = QLabel('corr')
117-
label7 = QLabel('rms')
118-
119-
self.yasa_freq_sp_display = QLabel()
120-
self.yasa_freq_sp_display.setStyleSheet("background-color: rgb(235, 235, 235);")
121-
self.yasa_freq_sp_display.setFont(text_font)
122-
self.yasa_freq_broad_display = QLabel()
123-
self.yasa_freq_broad_display.setStyleSheet("background-color: rgb(235, 235, 235);")
124-
self.yasa_freq_broad_display.setFont(text_font)
125-
self.yasa_duration_display = QLabel()
126-
self.yasa_duration_display.setStyleSheet("background-color: rgb(235, 235, 235);")
127-
self.yasa_duration_display.setFont(text_font)
128-
self.yasa_min_distance_display = QLabel()
129-
self.yasa_min_distance_display.setStyleSheet("background-color: rgb(235, 235, 235);")
130-
self.yasa_min_distance_display.setFont(text_font)
131-
self.yasa_thresh_rel_pow_display = QLabel()
132-
self.yasa_thresh_rel_pow_display.setStyleSheet("background-color: rgb(235, 235, 235);")
133-
self.yasa_thresh_rel_pow_display.setFont(text_font)
134-
self.yasa_thresh_corr_display = QLabel()
135-
self.yasa_thresh_corr_display.setStyleSheet("background-color: rgb(235, 235, 235);")
136-
self.yasa_thresh_corr_display.setFont(text_font)
137-
self.yasa_thresh_rms_display = QLabel()
138-
self.yasa_thresh_rms_display.setStyleSheet("background-color: rgb(235, 235, 235);")
139-
self.yasa_thresh_rms_display.setFont(text_font)
140-
141-
self.yasa_detect_button = QPushButton('Detect')
142-
143-
# Add widgets to the grid layout
144-
yasa_parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0
145-
yasa_parameter_layout.addWidget(label2, 0, 1) # Row 0, Column 1
146-
yasa_parameter_layout.addWidget(self.yasa_freq_sp_display, 1, 0) # Row 1, Column 0
147-
yasa_parameter_layout.addWidget(self.yasa_freq_broad_display, 1, 1) # Row 1, Column 1
148-
yasa_parameter_layout.addWidget(label3, 2, 0)
149-
yasa_parameter_layout.addWidget(label4, 2, 1)
150-
yasa_parameter_layout.addWidget(self.yasa_duration_display, 3, 0)
151-
yasa_parameter_layout.addWidget(self.yasa_min_distance_display, 3, 1)
152-
153-
group_box = QGroupBox('thresh')
154-
thresh_parameter_layout = QVBoxLayout(group_box)
155-
thresh_parameter_layout.addWidget(label5)
156-
thresh_parameter_layout.addWidget(self.yasa_thresh_rel_pow_display)
157-
thresh_parameter_layout.addWidget(label6)
158-
thresh_parameter_layout.addWidget(self.yasa_thresh_corr_display)
159-
thresh_parameter_layout.addWidget(label7)
160-
thresh_parameter_layout.addWidget(self.yasa_thresh_rms_display)
161-
162-
yasa_parameter_layout.addWidget(group_box, 0, 2, 4, 1) # Row 0, Column 2, span 1 row, 6 columns
163-
yasa_parameter_layout.addWidget(self.mni_detect_button, 4, 2)
164-
165-
# Set the layout for the page
166-
layout.addWidget(detection_groupbox_yasa)
167-
page.setLayout(layout)
168-
return page
169-
170-
def create_detection_parameter_tab_yasa(self):
171-
tab = QWidget()
172-
layout = QGridLayout()
173-
174-
detection_groupbox = QGroupBox('Detection Parameters')
175-
parameter_layout = QGridLayout(detection_groupbox)
176-
177-
clear_layout(parameter_layout)
178-
179-
# Create widgets
180-
text_font = QFont('Arial', 11)
181-
label1 = QLabel('Freq Spindle')
182-
label2 = QLabel('Freq Broad')
183-
label3 = QLabel('Duration')
184-
label4 = QLabel('Min Distance')
185-
label5 = QLabel('Thresh-rel_pow')
186-
label6 = QLabel('Thresh-corr')
187-
label7 = QLabel('Thresh-rms')
188-
label8 = QLabel('Hz')
189-
label9 = QLabel('Hz')
190-
label10 = QLabel('sec')
191-
label11 = QLabel('ms')
192-
193-
self.yasa_freq_sp_input = QLineEdit()
194-
self.yasa_freq_sp_input.setFont(text_font)
195-
self.yasa_freq_broad_input = QLineEdit()
196-
self.yasa_freq_broad_input.setFont(text_font)
197-
self.yasa_duration_input = QLineEdit()
198-
self.yasa_duration_input.setFont(text_font)
199-
self.yasa_min_distance_input = QLineEdit()
200-
self.yasa_min_distance_input.setFont(text_font)
201-
self.yasa_thresh_rel_pow_input = QLineEdit()
202-
self.yasa_thresh_rel_pow_input.setFont(text_font)
203-
self.yasa_thresh_corr_input = QLineEdit()
204-
self.yasa_thresh_corr_input.setFont(text_font)
205-
self.yasa_thresh_rms_input = QLineEdit()
206-
self.yasa_thresh_rms_input.setFont(text_font)
207-
self.YASA_save_button = QPushButton('Save')
208-
209-
# Add widgets to the grid layout
210-
parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0
211-
parameter_layout.addWidget(self.yasa_freq_sp_input, 0, 1) # Row 0, Column 1
212-
parameter_layout.addWidget(label8, 0, 2)
213-
parameter_layout.addWidget(label2, 1, 0)
214-
parameter_layout.addWidget(self.yasa_freq_broad_input, 1, 1)
215-
parameter_layout.addWidget(label9, 1, 2)
216-
parameter_layout.addWidget(label3, 2, 0)
217-
parameter_layout.addWidget(self.yasa_duration_input, 2, 1)
218-
parameter_layout.addWidget(label10, 2, 2)
219-
parameter_layout.addWidget(label4, 3, 0)
220-
parameter_layout.addWidget(self.yasa_min_distance_input, 3, 1)
221-
parameter_layout.addWidget(label11, 3, 2)
222-
223-
parameter_layout.addWidget(label5, 4, 0)
224-
parameter_layout.addWidget(self.yasa_thresh_rel_pow_input, 4, 1)
225-
parameter_layout.addWidget(label6, 5, 0)
226-
parameter_layout.addWidget(self.yasa_thresh_corr_input, 5, 1)
227-
parameter_layout.addWidget(label7, 6, 0)
228-
parameter_layout.addWidget(self.yasa_thresh_rms_input, 6, 1)
229-
230-
parameter_layout.addWidget(self.YASA_save_button, 7, 2)
231-
232-
# Set the layout for the page
233-
layout.addWidget(detection_groupbox)
234-
tab.setLayout(layout)
235-
return tab

0 commit comments

Comments
 (0)