Skip to content

Commit ff351b1

Browse files
Changes to example notebook
1 parent 8a4dbe1 commit ff351b1

File tree

2 files changed

+107
-189
lines changed

2 files changed

+107
-189
lines changed

detectionmetrics/models/torch_detection.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,16 @@ def inference(self, image: Image.Image) -> Dict[str, torch.Tensor]:
284284
with torch.no_grad():
285285
result = self.model(tensor)[0] # Return only first image's result
286286

287+
# Apply threshold filtering from model config
288+
confidence_threshold = self.model_cfg.get("confidence_threshold", 0.5)
289+
if confidence_threshold > 0:
290+
keep_mask = result['scores'] >= confidence_threshold
291+
result = {
292+
'boxes': result['boxes'][keep_mask],
293+
'labels': result['labels'][keep_mask],
294+
'scores': result['scores'][keep_mask]
295+
}
296+
287297
return result
288298

289299
def eval(
@@ -352,12 +362,22 @@ def eval(
352362
gt = targets[i]
353363
pred = predictions[i]
354364

365+
# Apply confidence threshold filtering
366+
confidence_threshold = self.model_cfg.get("confidence_threshold", 0.5)
367+
if confidence_threshold > 0:
368+
keep_mask = pred['scores'] >= confidence_threshold
369+
pred = {
370+
'boxes': pred['boxes'][keep_mask],
371+
'labels': pred['labels'][keep_mask],
372+
'scores': pred['scores'][keep_mask]
373+
}
374+
355375
# Apply ontology translation if needed
356376
if lut_ontology is not None:
357377
gt["labels"] = lut_ontology[gt["labels"]]
358378

359379
# Update metrics
360-
metrics_factory.update(gt["boxes"], gt["labels"],pred["boxes"], pred["labels"], pred["scores"])
380+
metrics_factory.update(gt["boxes"], gt["labels"], pred["boxes"], pred["labels"], pred["scores"])
361381

362382
# Store predictions if needed
363383
if predictions_outdir is not None:

examples/tutorial_image_detection.ipynb

Lines changed: 86 additions & 188 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)