66from detectionmetrics .datasets .detection import ImageDetectionDataset
77
88
def build_coco_dataset(
    annotation_file: str,
    image_dir: str,
    coco_obj: Optional[COCO] = None,
    split: str = "train",
) -> Tuple[pd.DataFrame, dict]:
    """Build dataset and ontology dictionaries from COCO dataset structure

    :param annotation_file: Path to the COCO-format JSON annotation file
    :type annotation_file: str
    :param image_dir: Path to the directory containing image files
    :type image_dir: str
    :param coco_obj: Optional pre-loaded COCO object to reuse
    :type coco_obj: COCO
    :param split: Dataset split name (e.g., "train", "val", "test")
    :type split: str
    :raises FileNotFoundError: If ``annotation_file`` is not an existing file
    :raises NotADirectoryError: If ``image_dir`` is not an existing directory
    :return: Dataset DataFrame and ontology dictionary
    :rtype: Tuple[pd.DataFrame, dict]
    """
    # Validate paths with real exceptions: ``assert`` statements are removed
    # when Python runs with -O, which would silently skip these checks.
    if not os.path.isfile(annotation_file):
        raise FileNotFoundError(f"Annotation file not found: {annotation_file}")
    if not os.path.isdir(image_dir):
        raise NotADirectoryError(f"Image directory not found: {image_dir}")

    # Load COCO annotations (reuse if provided)
    coco = COCO(annotation_file) if coco_obj is None else coco_obj

    # Build ontology from COCO categories, keyed by category name
    ontology = {
        cat["name"]: {
            "idx": cat["id"],
            "rgb": [0, 0, 0],  # Placeholder; COCO doesn't define RGB colors
        }
        for cat in coco.loadCats(coco.getCatIds())
    }

    # Fetch every image record in one call instead of one call per image
    img_ids = coco.getImgIds()
    rows = [
        {
            "image": img_info["file_name"],
            "annotation": str(img_id),  # the image ID doubles as annotation key
            "split": split,
        }
        for img_id, img_info in zip(img_ids, coco.loadImgs(img_ids))
    ]

    # Pin the columns so a dataset with no images still has the expected schema
    dataset = pd.DataFrame(rows, columns=["image", "annotation", "split"])
    dataset.attrs = {"ontology": ontology}

    return dataset, ontology
5465
5566
@@ -61,26 +72,36 @@ class CocoDataset(ImageDetectionDataset):
6172 :type annotation_file: str
6273 :param image_dir: Path to the directory containing image files
6374 :type image_dir: str
75+ :param split: Dataset split name (e.g., "train", "val", "test")
76+ :type split: str
6477 """
def __init__(self, annotation_file: str, image_dir: str, split: str = "train"):
    """Load the COCO annotations and build the dataset table from them.

    :param annotation_file: Path to the COCO-format JSON annotation file
    :type annotation_file: str
    :param image_dir: Path to the directory containing image files
    :type image_dir: str
    :param split: Dataset split name (e.g., "train", "val", "test")
    :type split: str
    """
    # Parse the annotation file a single time and keep the handle on the
    # instance so it can be reused instead of re-reading the JSON.
    self.coco = COCO(annotation_file)
    self.image_dir = image_dir
    self.split = split

    # Hand the already-loaded COCO object to the builder along with the split
    frame, classes = build_coco_dataset(
        annotation_file,
        image_dir,
        self.coco,
        split=split,
    )

    super().__init__(dataset=frame, dataset_dir=image_dir, ontology=classes)
7491
75- def read_annotation (self , fname : str ) -> Tuple [List [List [float ]], List [int ], List [int ]]:
92+ def read_annotation (
93+ self , fname : str
94+ ) -> Tuple [List [List [float ]], List [int ], List [int ]]:
7695 """Return bounding boxes, labels, and category_ids for a given image ID.
7796
7897 :param fname: str (image_id in string form)
7998 :return: Tuple of (boxes, labels, category_ids)
8099 """
81100 # Extract image ID (fname might be a path or ID string)
82101 try :
83- image_id = int (os .path .basename (fname )) # handles both '123' and '/path/to/123'
102+ image_id = int (
103+ os .path .basename (fname )
104+ ) # handles both '123' and '/path/to/123'
84105 except ValueError :
85106 raise ValueError (f"Invalid annotation ID: { fname } " )
86107
@@ -99,4 +120,3 @@ def read_annotation(self, fname: str) -> Tuple[List[List[float]], List[int], Lis
99120 category_ids .append (ann ["category_id" ])
100121
101122 return boxes , labels , category_ids
102-
0 commit comments