Skip to content

Commit d96d1aa

Browse files
Renamed old MetricsFactory into SegmentationMetricsFactory
1 parent 3c848bc commit d96d1aa

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

detectionmetrics/models/tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def eval(
438438
ignored_label_indices.append(dataset_ontology[ignored_class]["idx"])
439439

440440
# Init metrics
441-
metrics_factory = um.MetricsFactory(self.n_classes)
441+
metrics_factory = um.SegmentationMetricsFactory(self.n_classes)
442442

443443
# Evaluation loop
444444
pbar = tqdm(dataset.dataset)
@@ -481,7 +481,7 @@ def eval(
481481
sample_valid_mask = (
482482
valid_mask[i] if valid_mask is not None else None
483483
)
484-
sample_mf = um.MetricsFactory(n_classes=self.n_classes)
484+
sample_mf = um.SegmentationMetricsFactory(n_classes=self.n_classes)
485485
sample_mf.update(sample_pred, sample_label, sample_valid_mask)
486486
sample_df = um.get_metrics_dataframe(sample_mf, self.ontology)
487487
sample_df.to_csv(

detectionmetrics/models/torch_segmentation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def eval(
530530
)
531531

532532
# Init metrics
533-
metrics_factory = um.MetricsFactory(self.n_classes)
533+
metrics_factory = um.SegmentationMetricsFactory(self.n_classes)
534534

535535
# Evaluation loop
536536
with torch.no_grad():
@@ -585,7 +585,7 @@ def eval(
585585
sample_valid_mask = (
586586
valid_mask[i] if valid_mask is not None else None
587587
)
588-
sample_mf = um.MetricsFactory(n_classes=self.n_classes)
588+
sample_mf = um.SegmentationMetricsFactory(n_classes=self.n_classes)
589589
sample_mf.update(
590590
sample_pred, sample_label, sample_valid_mask
591591
)
@@ -805,7 +805,7 @@ def eval(
805805
)
806806

807807
# Init metrics
808-
metrics_factory = um.MetricsFactory(self.n_classes)
808+
metrics_factory = um.SegmentationMetricsFactory(self.n_classes)
809809

810810
# Evaluation loop
811811
end_th = self.model_cfg.get("end_th", 0.5)
@@ -885,7 +885,7 @@ def eval(
885885
sample_valid_mask = (
886886
valid_mask[i] if valid_mask is not None else None
887887
)
888-
sample_mf = um.MetricsFactory(n_classes=self.n_classes)
888+
sample_mf = um.SegmentationMetricsFactory(n_classes=self.n_classes)
889889
sample_mf.update(
890890
sample_pred, sample_label, sample_valid_mask
891891
)

detectionmetrics/utils/metrics.py renamed to detectionmetrics/utils/segmentation_metrics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import pandas as pd
77

88

9-
class MetricsFactory:
10-
"""'Factory' class to accumulate results and compute metrics
9+
class SegmentationMetricsFactory:
10+
"""'Factory' class to accumulate results and compute metrics for segmentation tasks
1111
1212
:param n_classes: Number of classes to evaluate
1313
:type n_classes: int
@@ -256,12 +256,12 @@ def get_metric_per_name(
256256

257257

258258
def get_metrics_dataframe(
259-
metrics_factory: MetricsFactory, ontology: dict
259+
metrics_factory: SegmentationMetricsFactory, ontology: dict
260260
) -> pd.DataFrame:
261261
"""Build a DataFrame with all metrics (global and per class) plus confusion matrix
262262
263-
:param metrics_factory: Properly updated MetricsFactory object
264-
:type metrics_factory: MetricsFactory
263+
:param metrics_factory: Properly updated SegmentationMetricsFactory object
264+
:type metrics_factory: SegmentationMetricsFactory
265265
:param ontology: Ontology dictionary
266266
:type ontology: dict
267267
:return: DataFrame with all metrics

examples/tutorial_image_segmentation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@
314314
],
315315
"metadata": {
316316
"kernelspec": {
317-
"display_name": "detectionmetrics-cJs_3AVd-py3.10",
317+
"display_name": ".venv",
318318
"language": "python",
319319
"name": "python3"
320320
},
@@ -328,7 +328,7 @@
328328
"name": "python",
329329
"nbconvert_exporter": "python",
330330
"pygments_lexer": "ipython3",
331-
"version": "3.10.13"
331+
"version": "3.10.4"
332332
}
333333
},
334334
"nbformat": 4,

tests/test_metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import numpy as np
22
import pytest
3-
from detectionmetrics.utils.metrics import MetricsFactory
3+
from detectionmetrics.utils.metrics import SegmentationMetricsFactory
44

55

66
@pytest.fixture
77
def metrics_factory():
8-
"""Fixture to create a MetricsFactory instance for testing"""
9-
return MetricsFactory(n_classes=3)
8+
"""Fixture to create a SegmentationMetricsFactory instance for testing"""
9+
return SegmentationMetricsFactory(n_classes=3)
1010

1111

1212
def test_update_confusion_matrix(metrics_factory):
@@ -94,7 +94,7 @@ def test_edge_cases(metrics_factory):
9494
with pytest.raises(AssertionError):
9595
metrics_factory.update(pred, gt)
9696

97-
empty_metrics_factory = MetricsFactory(n_classes=3)
97+
empty_metrics_factory = SegmentationMetricsFactory(n_classes=3)
9898

9999
assert np.isnan(empty_metrics_factory.get_precision(per_class=False))
100100
assert np.isnan(empty_metrics_factory.get_recall(per_class=False))

0 commit comments

Comments
 (0)