Skip to content

Commit 7e9f403

Browse files
committed
Update torchmodel.py
1 parent ae7b90e commit 7e9f403

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

golf_federated/client/process/config/model/torchmodel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,9 @@ def test(self):
323323
correct_all += correct
324324
query_y_c = query_y.cpu().numpy() # ----
325325
query_pred_c = query_pred.cpu().numpy() # ----
326-
precision = precision_score(query_pred_c, query_y_c, average='macro')
327-
recall = recall_score(query_pred_c, query_y_c, average='macro')
328-
f1 = f1_score(query_pred_c, query_y_c, average='macro')
326+
precision = precision_score(query_pred_c, query_y_c, average='macro',zero_division=0)
327+
recall = recall_score(query_pred_c, query_y_c, average='macro',zero_division=0)
328+
f1 = f1_score(query_pred_c, query_y_c, average='macro',zero_division=0)
329329
mcc = matthews_corrcoef(query_pred_c, query_y_c)
330330
precision_all += precision
331331
recall_all += recall
@@ -393,9 +393,9 @@ def localize(self):
393393
total += len(query_y)
394394
query_y_c = query_y.cpu().numpy()
395395
query_pred_c = query_pred.cpu().numpy()
396-
precision = precision_score(query_pred_c, query_y_c, average='macro')
397-
recall = recall_score(query_pred_c, query_y_c, average='macro')
398-
f1 = f1_score(query_pred_c, query_y_c, average='macro')
396+
precision = precision_score(query_pred_c, query_y_c, average='macro',zero_division=0)
397+
recall = recall_score(query_pred_c, query_y_c, average='macro',zero_division=0)
398+
f1 = f1_score(query_pred_c, query_y_c, average='macro',zero_division=0)
399399
mcc = matthews_corrcoef(query_pred_c, query_y_c)
400400
precision_all += precision
401401
recall_all += recall

0 commit comments

Comments
 (0)