@@ -74,41 +74,41 @@ def update_model_toy(self, param:ParamClassifier):
7474 self .model_toy = model .to (self .device )
7575 self .preprocessing_artifact = PreProcessing .from_param (self .param_artifact_preprocessing )
7676
77- def artifact_detection (self , HFO_features , ignore_region , threshold = 0.5 ):
77+ def artifact_detection (self , biomarker_features , ignore_region , threshold = 0.5 ):
7878 if not self .model_toy :
7979 raise ValueError ("Please load artifact model first!" )
80- return self ._classify_artifacts (self .model_toy , HFO_features , ignore_region , threshold = threshold )
80+ return self ._classify_artifacts (self .model_toy , biomarker_features , ignore_region , threshold = threshold )
8181
82- def spike_detection (self , HFO_features ):
82+ def spike_detection (self , biomarker_features ):
8383 if not self .model_s :
8484 raise ValueError ("Please load spike model first!" )
85- return self ._classify_spikes (self .model_s , HFO_features )
85+ return self ._classify_spikes (self .model_s , biomarker_features )
8686
87- def _classify_artifacts (self , model , HFO_feature , ignore_region , threshold = 0.5 ):
87+ def _classify_artifacts (self , model , biomarker_feature , ignore_region , threshold = 0.5 ):
8888 model = model .to (self .device )
89- features = self .preprocessing_artifact .process_hfo_feature ( HFO_feature )
89+ features = self .preprocessing_artifact .process_biomarker_feature ( biomarker_feature )
9090 artifact_predictions = np .zeros (features .shape [0 ]) - 1
91- starts = HFO_feature .starts
92- ends = HFO_feature .ends
91+ starts = biomarker_feature .starts
92+ ends = biomarker_feature .ends
9393 keep_index = np .where (np .logical_and (starts > ignore_region [0 ], ends < ignore_region [1 ]) == True )[0 ]
9494 features = features [keep_index ]
9595 if len (features ) != 0 :
96- predictions = inference (model , features , self .device , self .batch_size , threshold = threshold )
96+ predictions = inference (model , features , self .device , self .batch_size , threshold = threshold )
9797 artifact_predictions [keep_index ] = predictions
98- HFO_feature .update_artifact_pred (artifact_predictions )
99- return HFO_feature
98+ biomarker_feature .update_artifact_pred (artifact_predictions )
99+ return biomarker_feature
100100
101- def _classify_spikes (self , model , HFO_feature ):
102- if len (HFO_feature .artifact_predictions ) == 0 :
101+ def _classify_spikes (self , model , biomarker_feature ):
102+ if len (biomarker_feature .artifact_predictions ) == 0 :
103103 raise ValueError ("Please run artifact classifier first!" )
104104 model = model .to (self .device )
105- features = self .preprocessing_spike .process_hfo_feature ( HFO_feature )
105+ features = self .preprocessing_spike .process_biomarker_feature ( biomarker_feature )
106106 spike_predictions = np .zeros (features .shape [0 ]) - 1
107- keep_index = np .where (HFO_feature .artifact_predictions > 0 )[0 ]
107+ keep_index = np .where (biomarker_feature .artifact_predictions > 0 )[0 ]
108108 features = features [keep_index ]
109109 if len (features ) != 0 :
110110 predictions = inference (model , features , self .device , self .batch_size )
111111 spike_predictions [keep_index ] = predictions
112- HFO_feature .update_spike_pred (spike_predictions )
113- return HFO_feature
112+ biomarker_feature .update_spike_pred (spike_predictions )
113+ return biomarker_feature
114114
0 commit comments