
Commit 5cd2ec6

Fixed the issue with the true positive count always being 0.
Parent: ff351b1

3 files changed: +1664 additions, -85 deletions

detectionmetrics/models/torch_detection.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def eval(
         )
 
         # Init metrics
-        metrics_factory = um.DetectionMetricsFactory(self.n_classes)
+        metrics_factory = um.DetectionMetricsFactory(iou_threshold=0.5, num_classes=self.n_classes)
 
         with torch.no_grad():
             pbar = tqdm(dataloader, leave=True)
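Note: the fix above passes an explicit IoU threshold alongside the class count. A minimal sketch of how the factory is driven, assuming `um` aliases detectionmetrics.utils.detection_metrics (the import path is an assumption) and using made-up boxes; the update() signature is taken from the diff below:

import numpy as np
import detectionmetrics.utils.detection_metrics as um  # assumed import path

# Hypothetical two-class setup; boxes are illustrative (x1, y1, x2, y2).
factory = um.DetectionMetricsFactory(iou_threshold=0.5, num_classes=2)

gt_boxes = np.array([[10.0, 10.0, 50.0, 50.0]])
gt_labels = np.array([0])
pred_boxes = np.array([[12.0, 11.0, 49.0, 52.0]])  # overlaps the ground truth well
pred_labels = np.array([0])
pred_scores = np.array([0.9])

factory.update(gt_boxes, gt_labels, pred_boxes, pred_labels, pred_scores)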

detectionmetrics/utils/detection_metrics.py

Lines changed: 39 additions & 2 deletions
@@ -28,6 +28,22 @@ def update(self, gt_boxes, gt_labels, pred_boxes, pred_labels, pred_scores):
         if hasattr(pred_labels, 'detach'): pred_labels = pred_labels.detach().cpu().numpy()
         if hasattr(pred_scores, 'detach'): pred_scores = pred_scores.detach().cpu().numpy()
 
+        # Handle empty inputs
+        if len(gt_boxes) == 0 and len(pred_boxes) == 0:
+            return  # Nothing to process
+
+        # Handle case where there are predictions but no ground truth
+        if len(gt_boxes) == 0:
+            for p_label, score in zip(pred_labels, pred_scores):
+                self.results[p_label].append((score, 0))  # All are false positives
+            return
+
+        # Handle case where there is ground truth but no predictions
+        if len(pred_boxes) == 0:
+            for g_label in gt_labels:
+                self.results[g_label].append((None, -1))  # All are false negatives
+            return
+
         matches = self._match_predictions(
             gt_boxes, gt_labels, pred_boxes, pred_labels, pred_scores
         )
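A reading of the convention assumed in this hunk: each entry appended to self.results[class_id] is a (score, flag) tuple, where flag 0 marks an unmatched prediction (false positive) and flag -1 marks a missed ground-truth box (false negative). A toy illustration with a plain dict standing in for self.results (not the library's actual structure):

from collections import defaultdict

results = defaultdict(list)  # hypothetical stand-in for self.results

# A prediction with no ground truth to match: recorded as a false positive.
results[3].append((0.87, 0))

# A ground-truth box with no matching prediction: recorded as a false negative.
results[3].append((None, -1))

print(results[3])  # [(0.87, 0), (None, -1)]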
@@ -147,6 +163,14 @@ def get_metrics_dataframe(self, ontology: dict) -> pd.DataFrame:
             values = [v for v in metrics_dict[metric].values() if not pd.isna(v)]
             metrics_dict[metric]["mean"] = np.mean(values) if values else np.nan
 
+        # Add overall mAP if available
+        if -1 in all_metrics:
+            for metric in ["AP", "Precision", "Recall", "TP", "FP", "FN"]:
+                if metric == "AP":
+                    metrics_dict[metric]["mAP"] = all_metrics[-1].get(metric, np.nan)
+                else:
+                    metrics_dict[metric]["mAP"] = np.nan
+
         df = pd.DataFrame(metrics_dict)
         return df.T  # metrics as rows, classes as columns (with mean)
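The hunk above adds an extra "mAP" column: the overall entry stored under key -1 in all_metrics fills the AP row, and every other metric row gets NaN in that column. A self-contained sketch of the resulting frame, with made-up all_metrics/metrics_dict values shaped the way the hunk suggests:

import numpy as np
import pandas as pd

all_metrics = {0: {"AP": 0.8}, 1: {"AP": 0.6}, -1: {"AP": 0.7}}  # -1 = overall
metrics_dict = {m: {} for m in ["AP", "Precision", "Recall", "TP", "FP", "FN"]}
metrics_dict["AP"] = {"class_0": 0.8, "class_1": 0.6, "mean": 0.7}

if -1 in all_metrics:
    for metric in ["AP", "Precision", "Recall", "TP", "FP", "FN"]:
        if metric == "AP":
            metrics_dict[metric]["mAP"] = all_metrics[-1].get(metric, np.nan)
        else:
            metrics_dict[metric]["mAP"] = np.nan

df = pd.DataFrame(metrics_dict)
print(df.T)  # metrics as rows; class columns plus "mean" and "mAP"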

@@ -177,15 +201,28 @@ def compute_ap(tps, fps, fn):
     tps = np.array(tps, dtype=np.float32)
     fps = np.array(fps, dtype=np.float32)
 
+    # Handle edge cases
+    if len(tps) == 0:
+        if fn == 0:
+            return 1.0, [1.0], [1.0]  # Perfect case: no predictions, no ground truth
+        else:
+            return 0.0, [0.0], [0.0]  # No predictions but there was ground truth
+
     tp_cumsum = np.cumsum(tps)
     fp_cumsum = np.cumsum(fps)
 
     if tp_cumsum.size:
         denom = tp_cumsum[-1] + fn
-        recalls = tp_cumsum / denom
+        if denom > 0:
+            recalls = tp_cumsum / denom
+        else:
+            recalls = np.zeros_like(tp_cumsum)
     else:
         recalls = []
-    precisions = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6) if tp_cumsum.size else []
+
+    # Compute precision with proper handling of division by zero
+    denominator = tp_cumsum + fp_cumsum
+    precisions = np.where(denominator > 0, tp_cumsum / denominator, 0.0)
 
     # VOC-style 11-point interpolation
     ap = 0
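To see what the guarded computation yields, a small worked example with illustrative numbers (three detections sorted by score, two of them matched, plus one missed ground-truth box):

import numpy as np

tps = np.array([1, 0, 1], dtype=np.float32)  # per-detection true-positive flags
fps = np.array([0, 1, 0], dtype=np.float32)  # per-detection false-positive flags
fn = 1                                       # ground-truth boxes never matched

tp_cumsum = np.cumsum(tps)                   # [1. 1. 2.]
fp_cumsum = np.cumsum(fps)                   # [0. 1. 1.]

denom = tp_cumsum[-1] + fn                   # 3 ground-truth boxes in total
recalls = tp_cumsum / denom if denom > 0 else np.zeros_like(tp_cumsum)

denominator = tp_cumsum + fp_cumsum
precisions = np.where(denominator > 0, tp_cumsum / denominator, 0.0)

print(recalls)     # [0.333 0.333 0.667] (approximately)
print(precisions)  # [1.  0.5 0.667] (approximately)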

examples/tutorial_image_detection.ipynb

Lines changed: 1624 additions & 82 deletions
Large diffs are not rendered by default.
