Skip to content

Commit d30ea21

Browse files
committed
add spindle and bug fix
1 parent 11438d6 commit d30ea21

31 files changed

+1532
-415
lines changed

main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22
import re
33
import sys
44
from src.ui.main_window import MainWindow
5+
from PyQt5.QtCore import Qt
56
from PyQt5.QtWidgets import *
67
import multiprocessing as mp
78
import torch
89
import warnings
910
warnings.filterwarnings("ignore")
1011

12+
# Enable DPI scaling
13+
QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)
14+
QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True)
15+
1116

1217
def closeAllWindows():
1318
QApplication.instance().closeAllWindows()

src/classifer.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,41 +74,41 @@ def update_model_toy(self, param:ParamClassifier):
7474
self.model_toy = model.to(self.device)
7575
self.preprocessing_artifact = PreProcessing.from_param(self.param_artifact_preprocessing)
7676

77-
def artifact_detection(self, HFO_features, ignore_region, threshold=0.5):
77+
def artifact_detection(self, biomarker_features, ignore_region, threshold=0.5):
7878
if not self.model_toy:
7979
raise ValueError("Please load artifact model first!")
80-
return self._classify_artifacts(self.model_toy, HFO_features, ignore_region, threshold=threshold)
80+
return self._classify_artifacts(self.model_toy, biomarker_features, ignore_region, threshold=threshold)
8181

82-
def spike_detection(self, HFO_features):
82+
def spike_detection(self, biomarker_features):
8383
if not self.model_s:
8484
raise ValueError("Please load spike model first!")
85-
return self._classify_spikes(self.model_s, HFO_features)
85+
return self._classify_spikes(self.model_s, biomarker_features)
8686

87-
def _classify_artifacts(self, model, HFO_feature, ignore_region, threshold=0.5):
87+
def _classify_artifacts(self, model, biomarker_feature, ignore_region, threshold=0.5):
8888
model = model.to(self.device)
89-
features = self.preprocessing_artifact.process_hfo_feature(HFO_feature)
89+
features = self.preprocessing_artifact.process_biomarker_feature(biomarker_feature)
9090
artifact_predictions = np.zeros(features.shape[0]) -1
91-
starts = HFO_feature.starts
92-
ends = HFO_feature.ends
91+
starts = biomarker_feature.starts
92+
ends = biomarker_feature.ends
9393
keep_index = np.where(np.logical_and(starts > ignore_region[0], ends < ignore_region[1]) == True)[0]
9494
features = features[keep_index]
9595
if len(features) != 0:
96-
predictions = inference(model, features, self.device ,self.batch_size, threshold=threshold)
96+
predictions = inference(model, features, self.device, self.batch_size, threshold=threshold)
9797
artifact_predictions[keep_index] = predictions
98-
HFO_feature.update_artifact_pred(artifact_predictions)
99-
return HFO_feature
98+
biomarker_feature.update_artifact_pred(artifact_predictions)
99+
return biomarker_feature
100100

101-
def _classify_spikes(self, model, HFO_feature):
102-
if len(HFO_feature.artifact_predictions) == 0:
101+
def _classify_spikes(self, model, biomarker_feature):
102+
if len(biomarker_feature.artifact_predictions) == 0:
103103
raise ValueError("Please run artifact classifier first!")
104104
model = model.to(self.device)
105-
features = self.preprocessing_spike.process_hfo_feature(HFO_feature)
105+
features = self.preprocessing_spike.process_biomarker_feature(biomarker_feature)
106106
spike_predictions = np.zeros(features.shape[0]) -1
107-
keep_index = np.where(HFO_feature.artifact_predictions > 0)[0]
107+
keep_index = np.where(biomarker_feature.artifact_predictions > 0)[0]
108108
features = features[keep_index]
109109
if len(features) != 0:
110110
predictions = inference(model, features, self.device, self.batch_size)
111111
spike_predictions[keep_index] = predictions
112-
HFO_feature.update_spike_pred(spike_predictions)
113-
return HFO_feature
112+
biomarker_feature.update_spike_pred(spike_predictions)
113+
return biomarker_feature
114114

src/controllers/annotation_controller.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,23 @@ def __init__(self, annotation_widget, backend=None):
88
self.model = AnnotationModel(backend)
99
self.view = AnnotationView(annotation_widget)
1010

11+
# define window length
12+
if self.model.backend.biomarker_type == "HFO":
13+
self.interval = 1.0
14+
elif self.model.backend.biomarker_type == "Spindle":
15+
self.interval = 4.0
16+
1117
def create_waveform_plot(self):
1218
self.model.create_waveform_plot()
1319
self.view.add_widget('VisulaizationVerticalLayout', self.model.waveform_plot)
1420
channel, start, end = self.get_current_event()
15-
self.model.waveform_plot.plot(start, end, channel, interval=1.0) # Default interval
21+
self.model.waveform_plot.plot(start, end, channel, interval=self.interval) # Default interval
1622

1723
def create_fft_plot(self):
1824
self.model.create_fft_plot()
1925
self.view.add_widget('FFT_layout', self.model.fft_plot)
2026
channel, start, end = self.get_current_event()
21-
self.model.fft_plot.plot(start, end, channel, interval=1.0) # Default interval
27+
self.model.fft_plot.plot(start, end, channel, interval=self.interval) # Default interval
2228

2329
def update_plots(self, start, end, channel, interval):
2430
self.model.waveform_plot.plot(start, end, channel, interval=interval)

src/controllers/main_waveform_plot_controller.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def get_current_start_end(self):
6363
def get_current_time_window(self):
6464
return self.model.get_current_time_window()
6565

66-
def set_plot_HFOs(self, plot_HFOs:bool):
67-
self.model.set_plot_HFOs(plot_HFOs)
66+
def set_plot_biomarkers(self, plot_biomarkers:bool):
67+
self.model.set_plot_biomarkers(plot_biomarkers)
6868

6969
def plot_all_current_channels_for_window(self):
7070
eeg_data_to_display, y_100_length, y_scale_length, offset_value = self.get_current_eeg_data_to_display()
@@ -78,22 +78,24 @@ def plot_all_current_channels_for_window(self):
7878

7979
return eeg_data_to_display, y_100_length, y_scale_length, offset_value
8080

81-
def plot_all_current_hfos_for_window(self, eeg_data_to_display, offset_value, top_value):
81+
def plot_all_current_biomarkers_for_window(self, eeg_data_to_display, offset_value, top_value):
8282
first_channel_to_plot = self.get_first_channel_to_plot()
8383
n_channels_to_plot = self.model.n_channels_to_plot
8484
channels_to_plot = self.model.channels_to_plot
8585
start_in_time, end_in_time = self.get_current_start_end()
8686

8787
for disp_i, ch_i in enumerate(range(first_channel_to_plot, first_channel_to_plot+n_channels_to_plot)):
8888
channel = channels_to_plot[ch_i]
89-
hfo_starts, hfo_ends, hfo_starts_in_time, hfo_ends_in_time, windows_in_time, colors = self.model.get_all_hfos_for_all_current_channels_and_color(channel)
90-
91-
if self.model.plot_HFOs:
92-
for i in range(len(hfo_starts)):
93-
event_start = int(hfo_starts[i]-start_in_time*self.model.sample_freq)
94-
event_end = int(hfo_ends[i]-start_in_time*self.model.sample_freq)
89+
(biomarker_starts, biomarker_ends,
90+
biomarker_starts_in_time, biomarker_ends_in_time,
91+
windows_in_time, colors) = self.model.get_all_biomarkers_for_all_current_channels_and_color(channel)
92+
93+
if self.model.plot_biomarkers:
94+
for i in range(len(biomarker_starts)):
95+
event_start = int(biomarker_starts[i]-start_in_time*self.model.sample_freq)
96+
event_end = int(biomarker_ends[i]-start_in_time*self.model.sample_freq)
9597
self.view.plot_waveform(windows_in_time[i], eeg_data_to_display[ch_i, event_start:event_end]-disp_i*offset_value, colors[i], 2)
96-
self.view.plot_waveform([hfo_starts_in_time[i], hfo_ends_in_time[i]], [top_value+0.2,top_value+0.2], colors[i], 10)
98+
self.view.plot_waveform([biomarker_starts_in_time[i], biomarker_ends_in_time[i]], [top_value+0.2,top_value+0.2], colors[i], 10)
9799

98100

99101
def draw_scale_bar(self, eeg_data_to_display, offset_value, y_100_length, y_scale_length):

src/controllers/main_window_controller.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from src.utils.utils_gui import *
2-
from src.hfo_app import HFO_App
32

43

54
class MainWindowController:
@@ -10,7 +9,7 @@ def __init__(self, view, model):
109
self.supported_biomarker = {
1110
'HFO': self.create_hfo_window,
1211
'Spindle': self.create_spindle_window,
13-
'Hypsarrhythmia': self.create_hypsarrhythmia_window,
12+
'Spike': self.create_spike_window,
1413
}
1514

1615
def init_biomarker_window(self, biomarker_type):
@@ -26,16 +25,12 @@ def init_general_window(self):
2625
def get_biomarker_type(self):
2726
return self.view.get_biomarker_type()
2827

28+
def set_biomarker_type(self, bio_type):
29+
self.model.set_biomarker_type_and_init_backend(bio_type)
30+
2931
def init_biomarker_type(self):
3032
default_biomarker = self.get_biomarker_type()
31-
backend = None
32-
if default_biomarker == 'HFO':
33-
backend = HFO_App()
34-
elif default_biomarker == 'Spindle':
35-
backend = HFO_App()
36-
elif default_biomarker == 'Spike':
37-
backend = HFO_App()
38-
self.model.set_backend(backend)
33+
self.set_biomarker_type(default_biomarker)
3934

4035
self.view.window.combo_box_biomarker.currentIndexChanged.connect(self.switch_biomarker)
4136

@@ -44,6 +39,9 @@ def switch_biomarker(self):
4439
self.supported_biomarker[selected_biomarker]()
4540

4641
def create_hfo_window(self):
42+
# set biomarker type
43+
self.set_biomarker_type('HFO')
44+
4745
# dynamically create frame for different biomarkers
4846
self.view.window.frame_biomarker_layout = QHBoxLayout(self.view.window.frame_biomarker_type)
4947

@@ -67,6 +65,9 @@ def create_hfo_window(self):
6765
self.model.init_param('HFO')
6866

6967
def create_spindle_window(self):
68+
# set biomarker type
69+
self.set_biomarker_type('Spindle')
70+
7071
# create detection parameters stacked widget
7172
self.view.create_stacked_widget_detection_param('Spindle')
7273

@@ -85,5 +86,5 @@ def create_spindle_window(self):
8586
# init params
8687
self.model.init_param('Spindle')
8788

88-
def create_hypsarrhythmia_window(self):
89+
def create_spike_window(self):
8990
print('not implemented yet')

src/controllers/mini_plot_controller.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, mini_plot_widget, backend):
1111
def clear(self):
1212
self.view.clear()
1313

14-
def init_hfo_display(self):
14+
def init_biomarker_display(self):
1515
self.view.enable_axis_information()
1616
self.view.add_linear_region()
1717

@@ -30,25 +30,28 @@ def set_channels_to_plot(self, channels_to_plot):
3030
def set_n_channels_to_plot(self, n_channels_to_plot):
3131
self.model.set_n_channels_to_plot(n_channels_to_plot)
3232

33+
def set_first_channel_to_plot(self, first_channel_to_plot):
34+
self.model.set_first_channel_to_plot(first_channel_to_plot)
35+
3336
def update_channel_names(self, new_channel_names):
3437
self.model.update_channel_names(new_channel_names)
3538

36-
def plot_one_hfo(self, start_time, end_time, height, color, width=5):
37-
self.view.plot_hfo(start_time, end_time, height, color, width)
39+
def plot_one_biomarker(self, start_time, end_time, height, color, width=5):
40+
self.view.plot_biomarker(start_time, end_time, height, color, width)
3841

39-
def plot_all_current_hfos_for_one_channel(self, channel, plot_height):
40-
starts_in_time, ends_in_time, colors = self.model.get_all_hfos_for_channel_and_color(channel)
42+
def plot_all_current_biomarkers_for_one_channel(self, channel, plot_height):
43+
starts_in_time, ends_in_time, colors = self.model.get_all_biomarkers_for_channel_and_color(channel)
4144

4245
for i in range(len(starts_in_time)):
43-
self.plot_one_hfo(starts_in_time[i], ends_in_time[i], plot_height, colors[i], 5)
46+
self.plot_one_biomarker(starts_in_time[i], ends_in_time[i], plot_height, colors[i], 5)
4447

45-
def plot_all_current_hfos_for_all_channels(self, plot_height):
48+
def plot_all_current_biomarkers_for_all_channels(self, plot_height):
4649
first_channel_to_plot = self.get_first_channel_to_plot()
4750
n_channels_to_plot = self.model.n_channels_to_plot
4851
channels_to_plot = self.model.channels_to_plot
4952
for disp_i, ch_i in enumerate(range(first_channel_to_plot, first_channel_to_plot+n_channels_to_plot)):
5053
channel = channels_to_plot[ch_i]
51-
self.plot_all_current_hfos_for_one_channel(channel, plot_height)
54+
self.plot_all_current_biomarkers_for_one_channel(channel, plot_height)
5255

5356
def set_miniplot_title(self, title, height):
5457
self.view.set_miniplot_title(title, height)

0 commit comments

Comments
 (0)