Skip to content

Commit 70412ad

Browse files
Made all the changes requested
1 parent ce88007 commit 70412ad

18 files changed

+343
-647
lines changed

app.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

detectionmetrics/datasets/coco.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from detectionmetrics.datasets.detection import ImageDetectionDataset
77

88

9-
def build_coco_dataset(annotation_file: str, image_dir: str, coco_obj: Optional[COCO] = None) -> Tuple[pd.DataFrame, dict]:
9+
def build_coco_dataset(
10+
annotation_file: str,
11+
image_dir: str,
12+
coco_obj: Optional[COCO] = None,
13+
split: str = "train",
14+
) -> Tuple[pd.DataFrame, dict]:
1015
"""Build dataset and ontology dictionaries from COCO dataset structure
1116
1217
:param annotation_file: Path to the COCO-format JSON annotation file
@@ -15,41 +20,47 @@ def build_coco_dataset(annotation_file: str, image_dir: str, coco_obj: Optional[
1520
:type image_dir: str
1621
:param coco_obj: Optional pre-loaded COCO object to reuse
1722
:type coco_obj: COCO
23+
:param split: Dataset split name (e.g., "train", "val", "test")
24+
:type split: str
1825
:return: Dataset DataFrame and ontology dictionary
1926
:rtype: Tuple[pd.DataFrame, dict]
2027
"""
2128
# Check that provided paths exist
22-
assert os.path.isfile(annotation_file), f"Annotation file not found: {annotation_file}"
29+
assert os.path.isfile(
30+
annotation_file
31+
), f"Annotation file not found: {annotation_file}"
2332
assert os.path.isdir(image_dir), f"Image directory not found: {image_dir}"
2433

2534
# Load COCO annotations (reuse if provided)
2635
if coco_obj is None:
2736
coco = COCO(annotation_file)
2837
else:
2938
coco = coco_obj
30-
39+
3140
# Build ontology from COCO categories
3241
ontology = {}
3342
for cat in coco.loadCats(coco.getCatIds()):
3443
ontology[cat["name"]] = {
3544
"idx": cat["id"],
3645
# "name": cat["name"],
37-
"rgb": [0, 0, 0] # Placeholder; COCO doesn't define RGB colors
46+
"rgb": [0, 0, 0], # Placeholder; COCO doesn't define RGB colors
3847
}
3948

4049
# Build dataset DataFrame from COCO image IDs
4150
rows = []
4251
for img_id in coco.getImgIds():
4352
img_info = coco.loadImgs(img_id)[0]
44-
rows.append({
45-
"image": img_info["file_name"],
46-
"annotation": str(img_id),
47-
"split": "train" # Default split - could be enhanced to read from COCO
48-
})
49-
53+
rows.append(
54+
{
55+
"image": img_info["file_name"],
56+
"annotation": str(img_id),
57+
"split": split, # Use provided split parameter
58+
}
59+
)
60+
5061
dataset = pd.DataFrame(rows)
5162
dataset.attrs = {"ontology": ontology}
52-
63+
5364
return dataset, ontology
5465

5566

@@ -61,26 +72,36 @@ class CocoDataset(ImageDetectionDataset):
6172
:type annotation_file: str
6273
:param image_dir: Path to the directory containing image files
6374
:type image_dir: str
75+
:param split: Dataset split name (e.g., "train", "val", "test")
76+
:type split: str
6477
"""
65-
def __init__(self, annotation_file: str, image_dir: str):
78+
79+
def __init__(self, annotation_file: str, image_dir: str, split: str = "train"):
6680
# Load COCO object once
6781
self.coco = COCO(annotation_file)
6882
self.image_dir = image_dir
69-
70-
# Build dataset using the same COCO object
71-
dataset, ontology = build_coco_dataset(annotation_file, image_dir, self.coco)
72-
83+
self.split = split
84+
85+
# Build dataset using the same COCO object and split
86+
dataset, ontology = build_coco_dataset(
87+
annotation_file, image_dir, self.coco, split=split
88+
)
89+
7390
super().__init__(dataset=dataset, dataset_dir=image_dir, ontology=ontology)
7491

75-
def read_annotation(self, fname: str) -> Tuple[List[List[float]], List[int], List[int]]:
92+
def read_annotation(
93+
self, fname: str
94+
) -> Tuple[List[List[float]], List[int], List[int]]:
7695
"""Return bounding boxes, labels, and category_ids for a given image ID.
7796
7897
:param fname: str (image_id in string form)
7998
:return: Tuple of (boxes, labels, category_ids)
8099
"""
81100
# Extract image ID (fname might be a path or ID string)
82101
try:
83-
image_id = int(os.path.basename(fname)) # handles both '123' and '/path/to/123'
102+
image_id = int(
103+
os.path.basename(fname)
104+
) # handles both '123' and '/path/to/123'
84105
except ValueError:
85106
raise ValueError(f"Invalid annotation ID: {fname}")
86107

@@ -99,4 +120,3 @@ def read_annotation(self, fname: str) -> Tuple[List[List[float]], List[int], Lis
99120
category_ids.append(ann["category_id"])
100121

101122
return boxes, labels, category_ids
102-

detectionmetrics/datasets/detection.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import detectionmetrics.utils.io as uio
1414
import detectionmetrics.utils.conversion as uc
1515

16+
1617
class DetectionDataset(PerceptionDataset):
1718
"""Abstract perception detection dataset class."""
1819

@@ -23,7 +24,7 @@ def read_annotation(self, fname: str):
2324
:param fname: Annotation file name
2425
"""
2526
raise NotImplementedError
26-
27+
2728
def get_label_count(self, splits: Optional[List[str]] = None):
2829
"""Count detection labels per class for given splits.
2930
@@ -40,7 +41,9 @@ def get_label_count(self, splits: Optional[List[str]] = None):
4041
for annotation_file in tqdm(df["annotation"], desc="Counting labels"):
4142
annots = self.read_annotation(annotation_file)
4243
for annot in annots:
43-
class_idx = annot["category_id"] #Should override the key category_id if needed in specific dataset class
44+
class_idx = annot[
45+
"category_id"
46+
] # Should override the key category_id if needed in specific dataset class
4447
label_count[class_idx] += 1
4548

4649
return label_count
@@ -58,7 +61,7 @@ def make_fname_global(self):
5861
self.dataset["annotation"] = self.dataset["annotation"].apply(
5962
lambda x: os.path.join(self.dataset_dir, x) if x is not None else None
6063
)
61-
self.dataset_dir = None
64+
self.dataset_dir = None
6265

6366
def read_annotation(self, fname: str):
6467
"""Read detection annotation from a file.
@@ -75,7 +78,13 @@ def read_annotation(self, fname: str):
7578
class LiDARDetectionDataset(DetectionDataset):
7679
"""LiDAR detection dataset class."""
7780

78-
def __init__(self, dataset: pd.DataFrame, dataset_dir: str, ontology: dict, is_kitti_format: bool = True):
81+
def __init__(
82+
self,
83+
dataset: pd.DataFrame,
84+
dataset_dir: str,
85+
ontology: dict,
86+
is_kitti_format: bool = True,
87+
):
7988
super().__init__(dataset, dataset_dir, ontology)
8089
self.is_kitti_format = is_kitti_format
8190

@@ -98,4 +107,4 @@ def read_annotation(self, fname: str):
98107
:return: Parsed annotations (e.g., list of dicts)
99108
"""
100109
# TODO Implement format specific parsing
101-
raise NotImplementedError("Implement LiDAR detection annotation reading")
110+
raise NotImplementedError("Implement LiDAR detection annotation reading")

detectionmetrics/datasets/perception.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ def append(self, new_dataset: Self):
5757
self.ontology[class_name]["idx"]
5858
== new_dataset.ontology[class_name]["idx"]
5959
), "Ontologies don't match"
60-
if "rgb" in self.ontology[class_name] and "rgb" in new_dataset.ontology[class_name]:
60+
if (
61+
"rgb" in self.ontology[class_name]
62+
and "rgb" in new_dataset.ontology[class_name]
63+
):
6164
assert (
6265
self.ontology[class_name]["rgb"]
6366
== new_dataset.ontology[class_name]["rgb"]

detectionmetrics/datasets/rellis3d.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def build_dataset(
7373
return dataset, ontology
7474

7575

76-
class Rellis3DImageSegmentationDataset(dm_segmentation_dataset.ImageSegmentationDataset):
76+
class Rellis3DImageSegmentationDataset(
77+
dm_segmentation_dataset.ImageSegmentationDataset
78+
):
7779
"""Specific class for Rellis3D-styled image segmentation datasets. All data can
7880
be downloaded from the official repo (https://github.com/unmannedlab/RELLIS-3D):
7981
images -> https://drive.google.com/file/d/1F3Leu0H_m6aPVpZITragfreO_SGtL2yV
@@ -105,7 +107,9 @@ def __init__(self, dataset_dir: str, split_dir: str, ontology_fname: str):
105107
super().__init__(dataset, dataset_dir, ontology)
106108

107109

108-
class Rellis3DLiDARSegmentationDataset(dm_segmentation_dataset.LiDARSegmentationDataset):
110+
class Rellis3DLiDARSegmentationDataset(
111+
dm_segmentation_dataset.LiDARSegmentationDataset
112+
):
109113
"""Specific class for Rellis3D-styled LiDAR segmentation datasets. All data can
110114
be downloaded from the official repo (https://github.com/unmannedlab/RELLIS-3D):
111115
points -> https://drive.google.com/file/d/1lDSVRf_kZrD0zHHMsKJ0V1GN9QATR4wH

detectionmetrics/datasets/segmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import detectionmetrics.utils.io as uio
1414
import detectionmetrics.utils.conversion as uc
1515

16+
1617
class SegmentationDataset(PerceptionDataset):
1718
"""Abstract perception dataset class."""
1819

@@ -383,4 +384,3 @@ def read_label(fname: str) -> Tuple[np.ndarray, np.ndarray]:
383384
semantic_label = label & 0xFFFF
384385
instance_label = label >> 16
385386
return semantic_label.astype(np.int32), instance_label.astype(np.int32)
386-

detectionmetrics/datasets/wildscenes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def build_dataset(dataset_dir: str, split_fnames: dict) -> Tuple[dict, dict]:
123123
return dataset, ontology
124124

125125

126-
class WildscenesImageSegmentationDataset(dm_segmentation_dataset.ImageSegmentationDataset):
126+
class WildscenesImageSegmentationDataset(
127+
dm_segmentation_dataset.ImageSegmentationDataset
128+
):
127129
"""Specific class for Wildscenes-styled image segmentation datasets. All data can
128130
be downloaded from the official repo (https://github.com/unmannedlab/RELLIS-3D):
129131
dataset -> https://data.csiro.au/collection/csiro:61541

detectionmetrics/models/detection.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import detectionmetrics.utils.conversion as uc
1313
import detectionmetrics.utils.io as uio
1414

15+
1516
class DetectionModel(PerceptionModel):
1617
"""Parent detection model class
1718
@@ -38,9 +39,7 @@ def __init__(
3839
super().__init__(model, model_type, model_cfg, ontology_fname, model_fname)
3940

4041
@abstractmethod
41-
def inference(
42-
self, data: Union[np.ndarray, Image.Image]
43-
) -> List[dict]:
42+
def inference(self, data: Union[np.ndarray, Image.Image]) -> List[dict]:
4443
"""Perform inference for a single input (image or point cloud)
4544
4645
:param data: Input image or LiDAR point cloud
@@ -76,6 +75,7 @@ def eval(
7675
"""
7776
raise NotImplementedError
7877

78+
7979
class ImageDetectionModel(DetectionModel):
8080
"""Parent image detection model class
8181
@@ -138,6 +138,7 @@ def eval(
138138
"""
139139
raise NotImplementedError
140140

141+
141142
class LiDARDetectionModel(DetectionModel):
142143
"""Parent LiDAR detection model class
143144
@@ -198,4 +199,4 @@ def eval(
198199
:return: DataFrame containing evaluation metrics
199200
:rtype: pd.DataFrame
200201
"""
201-
raise NotImplementedError
202+
raise NotImplementedError

detectionmetrics/models/segmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import detectionmetrics.utils.conversion as uc
1313
import detectionmetrics.utils.io as uio
1414

15+
1516
class SegmentationModel(PerceptionModel):
1617
"""Parent segmentation model class
1718
@@ -201,4 +202,3 @@ def eval(
201202
:rtype: pd.DataFrame
202203
"""
203204
raise NotImplementedError
204-

0 commit comments

Comments
 (0)