File tree Expand file tree Collapse file tree 5 files changed +17
-17
lines changed
Expand file tree Collapse file tree 5 files changed +17
-17
lines changed Original file line number Diff line number Diff line change @@ -438,7 +438,7 @@ def eval(
438438 ignored_label_indices .append (dataset_ontology [ignored_class ]["idx" ])
439439
440440 # Init metrics
441- metrics_factory = um .MetricsFactory (self .n_classes )
441+ metrics_factory = um .SegmentationMetricsFactory (self .n_classes )
442442
443443 # Evaluation loop
444444 pbar = tqdm (dataset .dataset )
@@ -481,7 +481,7 @@ def eval(
481481 sample_valid_mask = (
482482 valid_mask [i ] if valid_mask is not None else None
483483 )
484- sample_mf = um .MetricsFactory (n_classes = self .n_classes )
484+ sample_mf = um .SegmentationMetricsFactory (n_classes = self .n_classes )
485485 sample_mf .update (sample_pred , sample_label , sample_valid_mask )
486486 sample_df = um .get_metrics_dataframe (sample_mf , self .ontology )
487487 sample_df .to_csv (
Original file line number Diff line number Diff line change @@ -530,7 +530,7 @@ def eval(
530530 )
531531
532532 # Init metrics
533- metrics_factory = um .MetricsFactory (self .n_classes )
533+ metrics_factory = um .SegmentationMetricsFactory (self .n_classes )
534534
535535 # Evaluation loop
536536 with torch .no_grad ():
@@ -585,7 +585,7 @@ def eval(
585585 sample_valid_mask = (
586586 valid_mask [i ] if valid_mask is not None else None
587587 )
588- sample_mf = um .MetricsFactory (n_classes = self .n_classes )
588+ sample_mf = um .SegmentationMetricsFactory (n_classes = self .n_classes )
589589 sample_mf .update (
590590 sample_pred , sample_label , sample_valid_mask
591591 )
@@ -805,7 +805,7 @@ def eval(
805805 )
806806
807807 # Init metrics
808- metrics_factory = um .MetricsFactory (self .n_classes )
808+ metrics_factory = um .SegmentationMetricsFactory (self .n_classes )
809809
810810 # Evaluation loop
811811 end_th = self .model_cfg .get ("end_th" , 0.5 )
@@ -885,7 +885,7 @@ def eval(
885885 sample_valid_mask = (
886886 valid_mask [i ] if valid_mask is not None else None
887887 )
888- sample_mf = um .MetricsFactory (n_classes = self .n_classes )
888+ sample_mf = um .SegmentationMetricsFactory (n_classes = self .n_classes )
889889 sample_mf .update (
890890 sample_pred , sample_label , sample_valid_mask
891891 )
Original file line number Diff line number Diff line change 66import pandas as pd
77
88
9- class MetricsFactory :
10- """'Factory' class to accumulate results and compute metrics
9+ class SegmentationMetricsFactory :
10+ """'Factory' class to accumulate results and compute metrics for segmentation tasks
1111
1212 :param n_classes: Number of classes to evaluate
1313 :type n_classes: int
@@ -256,12 +256,12 @@ def get_metric_per_name(
256256
257257
258258def get_metrics_dataframe (
259- metrics_factory : MetricsFactory , ontology : dict
259+ metrics_factory : SegmentationMetricsFactory , ontology : dict
260260) -> pd .DataFrame :
261261 """Build a DataFrame with all metrics (global and per class) plus confusion matrix
262262
263- :param metrics_factory: Properly updated MetricsFactory object
264- :type metrics_factory: MetricsFactory
263+ :param metrics_factory: Properly updated SegmentationMetricsFactory object
264+ :type metrics_factory: SegmentationMetricsFactory
265265 :param ontology: Ontology dictionary
266266 :type ontology: dict
267267 :return: DataFrame with all metrics
Original file line number Diff line number Diff line change 314314 ],
315315 "metadata" : {
316316 "kernelspec" : {
317- "display_name" : " detectionmetrics-cJs_3AVd-py3.10 " ,
317+ "display_name" : " .venv " ,
318318 "language" : " python" ,
319319 "name" : " python3"
320320 },
328328 "name" : " python" ,
329329 "nbconvert_exporter" : " python" ,
330330 "pygments_lexer" : " ipython3" ,
331- "version" : " 3.10.13 "
331+ "version" : " 3.10.4 "
332332 }
333333 },
334334 "nbformat" : 4 ,
Original file line number Diff line number Diff line change 11import numpy as np
22import pytest
3- from detectionmetrics .utils .metrics import MetricsFactory
3+ from detectionmetrics .utils .metrics import SegmentationMetricsFactory
44
55
66@pytest .fixture
77def metrics_factory ():
8- """Fixture to create a MetricsFactory instance for testing"""
9- return MetricsFactory (n_classes = 3 )
8+ """Fixture to create a SegmentationMetricsFactory instance for testing"""
9+ return SegmentationMetricsFactory (n_classes = 3 )
1010
1111
1212def test_update_confusion_matrix (metrics_factory ):
@@ -94,7 +94,7 @@ def test_edge_cases(metrics_factory):
9494 with pytest .raises (AssertionError ):
9595 metrics_factory .update (pred , gt )
9696
97- empty_metrics_factory = MetricsFactory (n_classes = 3 )
97+ empty_metrics_factory = SegmentationMetricsFactory (n_classes = 3 )
9898
9999 assert np .isnan (empty_metrics_factory .get_precision (per_class = False ))
100100 assert np .isnan (empty_metrics_factory .get_recall (per_class = False ))
You can’t perform that action at this time.
0 commit comments