@@ -323,9 +323,9 @@ def test(self):
323323 correct_all += correct
324324 query_y_c = query_y .cpu ().numpy () # ----
325325 query_pred_c = query_pred .cpu ().numpy () # ----
326- precision = precision_score (query_pred_c , query_y_c , average = 'macro' )
327- recall = recall_score (query_pred_c , query_y_c , average = 'macro' )
328- f1 = f1_score (query_pred_c , query_y_c , average = 'macro' )
326+ precision = precision_score (query_pred_c , query_y_c , average = 'macro' , zero_division = 0 )
327+ recall = recall_score (query_pred_c , query_y_c , average = 'macro' , zero_division = 0 )
328+ f1 = f1_score (query_pred_c , query_y_c , average = 'macro' , zero_division = 0 )
329329 mcc = matthews_corrcoef (query_pred_c , query_y_c )
330330 precision_all += precision
331331 recall_all += recall
@@ -393,9 +393,9 @@ def localize(self):
393393 total += len (query_y )
394394 query_y_c = query_y .cpu ().numpy ()
395395 query_pred_c = query_pred .cpu ().numpy ()
396- precision = precision_score (query_pred_c , query_y_c , average = 'macro' )
397- recall = recall_score (query_pred_c , query_y_c , average = 'macro' )
398- f1 = f1_score (query_pred_c , query_y_c , average = 'macro' )
396+ precision = precision_score (query_pred_c , query_y_c , average = 'macro' , zero_division = 0 )
397+ recall = recall_score (query_pred_c , query_y_c , average = 'macro' , zero_division = 0 )
398+ f1 = f1_score (query_pred_c , query_y_c , average = 'macro' , zero_division = 0 )
399399 mcc = matthews_corrcoef (query_pred_c , query_y_c )
400400 precision_all += precision
401401 recall_all += recall
0 commit comments