Commit 3c848bc

Black formatting done
1 parent b4b61d2 commit 3c848bc

File tree: 7 files changed (+485, -282 lines)
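
The commit applies Black's autoformatter across the app. As a minimal sketch (not part of the repository, and assuming the black package is installed), the two rewrites visible throughout the hunks below are quote normalization and splitting of calls that exceed the default 88-character line length:

import black

# Two of the pre-commit patterns seen below: single-quoted strings and an over-long call.
src = (
    "device = st.session_state.get('device', 'cpu')\n"
    "result = subprocess.run(['powershell', '-NoProfile', '-Command', script],"
    " capture_output=True, text=True, timeout=30)\n"
)

# format_str is Black's documented Python API; the CLI equivalent is `black .`
formatted = black.format_str(src, mode=black.Mode(line_length=88))
print(formatted)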

app.py

Lines changed: 66 additions & 29 deletions
@@ -6,6 +6,7 @@
 from tabs.inference import inference_tab
 from tabs.evaluator import evaluator_tab
 
+
 def browse_folder():
     """
     Opens a native folder selection dialog and returns the selected folder path.
@@ -15,32 +16,47 @@ def browse_folder():
     try:
         if sys.platform.startswith("win"):
             script = (
-                'Add-Type -AssemblyName System.windows.forms;'
-                '$f=New-Object System.Windows.Forms.FolderBrowserDialog;'
+                "Add-Type -AssemblyName System.windows.forms;"
+                "$f=New-Object System.Windows.Forms.FolderBrowserDialog;"
                 'if($f.ShowDialog() -eq "OK"){Write-Output $f.SelectedPath}'
             )
             result = subprocess.run(
                 ["powershell", "-NoProfile", "-Command", script],
-                capture_output=True, text=True, timeout=30
+                capture_output=True,
+                text=True,
+                timeout=30,
             )
             folder = result.stdout.strip()
             return folder if folder else None
         elif sys.platform == "darwin":
-            script = 'POSIX path of (choose folder with prompt "Select dataset folder:")'
+            script = (
+                'POSIX path of (choose folder with prompt "Select dataset folder:")'
+            )
             result = subprocess.run(
-                ["osascript", "-e", script],
-                capture_output=True, text=True, timeout=30
+                ["osascript", "-e", script], capture_output=True, text=True, timeout=30
             )
             folder = result.stdout.strip()
             return folder if folder else None
         else:
             # Linux: try zenity, then kdialog
             for cmd in [
-                ["zenity", "--file-selection", "--directory", "--title=Select dataset folder"],
-                ["kdialog", "--getexistingdirectory", "--title", "Select dataset folder"]
+                [
+                    "zenity",
+                    "--file-selection",
+                    "--directory",
+                    "--title=Select dataset folder",
+                ],
+                [
+                    "kdialog",
+                    "--getexistingdirectory",
+                    "--title",
+                    "Select dataset folder",
+                ],
             ]:
                 try:
-                    result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
+                    result = subprocess.run(
+                        cmd, capture_output=True, text=True, timeout=30
+                    )
                     folder = result.stdout.strip()
                     if folder:
                         return folder
@@ -50,6 +66,7 @@ def browse_folder():
     except Exception:
         return None
 
+
 st.set_page_config(page_title="DetectionMetrics", layout="wide")
 
 # st.title("DetectionMetrics")
@@ -91,7 +108,7 @@ def browse_folder():
         ["train", "val"],
         key="split_selectbox",
     )
-
+
     # Second row: Path and Browse button
     col1, col2 = st.columns([3, 1])
     with col1:
@@ -101,15 +118,17 @@ def browse_folder():
             key="dataset_path_input",
         )
     with col2:
-        st.markdown("<div style='margin-bottom: 1.75rem;'></div>", unsafe_allow_html=True)
+        st.markdown(
+            "<div style='margin-bottom: 1.75rem;'></div>", unsafe_allow_html=True
+        )
         if st.button("Browse", key="browse_button"):
             folder = browse_folder()
             if folder and os.path.isdir(folder):
                 st.session_state["dataset_path"] = folder
                 st.rerun()
             elif folder is not None:
                 st.warning("Selected path is not a valid folder.")
-
+
     if dataset_path_input != st.session_state.get("dataset_path", ""):
         st.session_state["dataset_path"] = dataset_path_input
 
@@ -132,7 +151,10 @@ def browse_folder():
         key="config_option",
         horizontal=True,
    )
-    if st.session_state.get("config_option", "Manual Configuration") == "Upload Config File":
+    if (
+        st.session_state.get("config_option", "Manual Configuration")
+        == "Upload Config File"
+    ):
         st.file_uploader(
             "Configuration File (.json)",
             type=["json"],
@@ -190,14 +212,13 @@ def browse_folder():
         value=st.session_state.get("evaluation_step", 10),
         step=1,
         key="evaluation_step",
-        help="Update UI with intermediate metrics every N images (0 = disable intermediate updates)"
+        help="Update UI with intermediate metrics every N images (0 = disable intermediate updates)",
     )
 
     # Load model action in sidebar
     from detectionmetrics.models.torch_detection import TorchImageDetectionModel
     import json, tempfile
 
-
     load_model_btn = st.button(
         "Load Model",
         type="primary",
@@ -209,8 +230,14 @@ def browse_folder():
     if load_model_btn:
         model_file = st.session_state.get("model_file")
         ontology_file = st.session_state.get("ontology_file")
-        config_option = st.session_state.get("config_option", "Manual Configuration")
-        config_file = st.session_state.get("config_file") if config_option == "Upload Config File" else None
+        config_option = st.session_state.get(
+            "config_option", "Manual Configuration"
+        )
+        config_file = (
+            st.session_state.get("config_file")
+            if config_option == "Upload Config File"
+            else None
+        )
 
         # Prepare configuration
         config_data = None
@@ -219,18 +246,22 @@ def browse_folder():
             if config_option == "Upload Config File":
                 if config_file is not None:
                     config_data = json.load(config_file)
-                    with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp_cfg:
+                    with tempfile.NamedTemporaryFile(
+                        delete=False, suffix=".json", mode="w"
+                    ) as tmp_cfg:
                         json.dump(config_data, tmp_cfg)
                         config_path = tmp_cfg.name
                 else:
                     st.error("Please upload a configuration file")
             else:
-                confidence_threshold = float(st.session_state.get('confidence_threshold', 0.5))
-                nms_threshold = float(st.session_state.get('nms_threshold', 0.5))
-                max_detections = int(st.session_state.get('max_detections', 100))
-                device = st.session_state.get('device', 'cpu')
-                batch_size = int(st.session_state.get('batch_size', 1))
-                evaluation_step = int(st.session_state.get('evaluation_step', 5))
+                confidence_threshold = float(
+                    st.session_state.get("confidence_threshold", 0.5)
+                )
+                nms_threshold = float(st.session_state.get("nms_threshold", 0.5))
+                max_detections = int(st.session_state.get("max_detections", 100))
+                device = st.session_state.get("device", "cpu")
+                batch_size = int(st.session_state.get("batch_size", 1))
+                evaluation_step = int(st.session_state.get("evaluation_step", 5))
                 config_data = {
                     "confidence_threshold": confidence_threshold,
                     "nms_threshold": nms_threshold,
@@ -239,7 +270,9 @@ def browse_folder():
                     "batch_size": batch_size,
                     "evaluation_step": evaluation_step,
                 }
-                with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp_cfg:
+                with tempfile.NamedTemporaryFile(
+                    delete=False, suffix=".json", mode="w"
+                ) as tmp_cfg:
                     json.dump(config_data, tmp_cfg)
                     config_path = tmp_cfg.name
         except Exception as e:
@@ -257,7 +290,9 @@ def browse_folder():
         # Persist ontology to temp file
         try:
             ontology_data = json.load(ontology_file)
-            with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp_ont:
+            with tempfile.NamedTemporaryFile(
+                delete=False, suffix=".json", mode="w"
+            ) as tmp_ont:
                 json.dump(ontology_data, tmp_ont)
                 ontology_path = tmp_ont.name
         except Exception as e:
@@ -266,7 +301,9 @@ def browse_folder():
 
         # Persist model to temp file
         try:
-            with tempfile.NamedTemporaryFile(delete=False, suffix='.pt', mode='wb') as tmp_model:
+            with tempfile.NamedTemporaryFile(
+                delete=False, suffix=".pt", mode="wb"
+            ) as tmp_model:
                 tmp_model.write(model_file.read())
                 model_temp_path = tmp_model.name
         except Exception as e:
@@ -279,7 +316,7 @@ def browse_folder():
                 model=model_temp_path,
                 model_cfg=config_path,
                 ontology_fname=ontology_path,
-                device=st.session_state.get('device', 'cpu'),
+                device=st.session_state.get("device", "cpu"),
             )
             st.session_state.detection_model = model
             st.session_state.detection_model_loaded = True
@@ -297,4 +334,4 @@ def browse_folder():
 with tab2:
     inference_tab()
 with tab3:
-    evaluator_tab()
+    evaluator_tab()
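
The model-loading code above repeatedly writes Streamlit uploads to NamedTemporaryFile(delete=False) before passing them to TorchImageDetectionModel, since the loader expects filesystem paths rather than in-memory upload objects. A minimal sketch of that pattern, with a hypothetical helper name that does not exist in the repo:

import json
import tempfile


def persist_uploaded_json(uploaded_file):
    """Write a Streamlit-uploaded JSON file to disk and return its path."""
    data = json.load(uploaded_file)  # the uploaded file behaves like an open file object
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as tmp:
        json.dump(data, tmp)
    # delete=False keeps the file on disk after the context manager closes it;
    # the caller is responsible for removing it once the model has been loaded.
    return tmp.name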

detectionmetrics/datasets/coco.py

Lines changed: 3 additions & 3 deletions
@@ -106,17 +106,17 @@ def read_annotation(
             image_id = int(os.path.basename(fname))
         except ValueError:
             raise ValueError(f"Invalid annotation ID: {fname}")
-
+
         # Use COCO's efficient indexing to get annotations for this image
         # getAnnIds() and loadAnns() are very fast due to COCO's internal indexing
         ann_ids = self.coco.getAnnIds(imgIds=image_id)
         anns = self.coco.loadAnns(ann_ids)
-
+
         boxes, labels, category_ids = [], [], []
         for ann in anns:
             x, y, w, h = ann["bbox"]
             boxes.append([x, y, x + w, y + h])
             labels.append(ann["category_id"])
             category_ids.append(ann["category_id"])
-
+
         return boxes, labels, category_ids
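
For context, the lookup this diff touches comes from pycocotools: COCO builds an image-id to annotation index at load time, and each annotation stores its box as [x, y, width, height]. A minimal standalone sketch of the same conversion, with a hypothetical annotation file path:

from pycocotools.coco import COCO

coco = COCO("annotations/instances_val.json")  # hypothetical path
ann_ids = coco.getAnnIds(imgIds=42)  # fast lookup via the prebuilt image -> annotation index
anns = coco.loadAnns(ann_ids)

boxes, labels = [], []
for ann in anns:
    x, y, w, h = ann["bbox"]  # COCO stores [x, y, width, height]
    boxes.append([x, y, x + w, y + h])  # corner format used by read_annotation above
    labels.append(ann["category_id"])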

detectionmetrics/models/torch_detection.py

Lines changed: 20 additions & 14 deletions
@@ -361,12 +361,14 @@ def eval(
             num_workers = 0
         else:
             num_workers = self.model_cfg.get("num_workers")
-
+
         dataloader = DataLoader(
             dataset,
             batch_size=self.model_cfg.get("batch_size", 1),
             num_workers=num_workers,
-            collate_fn=lambda batch: tuple(zip(*batch)),  # handles variable-size targets
+            collate_fn=lambda batch: tuple(
+                zip(*batch)
+            ),  # handles variable-size targets
         )
 
         # Get iou_threshold from model config, default to 0.5 if not present
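
As the inline comment notes, the custom collate_fn is needed because detection targets vary in size per image, so the default batching cannot stack them into a single tensor. A minimal sketch of what tuple(zip(*batch)) does to a batch of (image_id, image, target) samples:

import torch

# Two samples with different numbers of boxes; the default collate would fail to stack them.
batch = [
    (
        "img_0",
        torch.rand(3, 64, 64),
        {"boxes": torch.rand(2, 4), "labels": torch.tensor([1, 3])},
    ),
    (
        "img_1",
        torch.rand(3, 64, 64),
        {"boxes": torch.rand(5, 4), "labels": torch.tensor([2, 7, 7, 1, 4])},
    ),
]

# Same expression used as collate_fn above: regroup per field instead of stacking.
image_ids, images, targets = tuple(zip(*batch))
print(len(images), targets[1]["boxes"].shape)  # -> 2 torch.Size([5, 4])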
@@ -394,7 +396,7 @@ def eval(
             iterator = pbar
         else:
             iterator = dataloader
-
+
         for image_ids, images, targets in iterator:
             # Defensive check for empty images
             if not images or any(img.numel() == 0 for img in images):
@@ -481,29 +483,33 @@ def eval(
                             predictions_outdir, f"{sample_id}_metrics.csv"
                         )
                     )
-
+
                 processed_samples += 1
-
+
                 # Call progress callback if provided
                 if progress_callback is not None:
                     progress_callback(processed_samples, total_samples)
-
+
                 # Call metrics callback if provided and evaluation_step is reached
-                if (metrics_callback is not None and
-                    evaluation_step is not None and
-                    processed_samples % evaluation_step == 0):
+                if (
+                    metrics_callback is not None
+                    and evaluation_step is not None
+                    and processed_samples % evaluation_step == 0
+                ):
                     # Get intermediate metrics
-                    intermediate_metrics = metrics_factory.get_metrics_dataframe(self.ontology)
-                    metrics_callback(intermediate_metrics, processed_samples, total_samples)
+                    intermediate_metrics = metrics_factory.get_metrics_dataframe(
+                        self.ontology
+                    )
+                    metrics_callback(
+                        intermediate_metrics, processed_samples, total_samples
+                    )
 
         # Return both the DataFrame and the metrics factory for access to precision-recall curves
         return {
             "metrics_df": metrics_factory.get_metrics_dataframe(self.ontology),
-            "metrics_factory": metrics_factory
+            "metrics_factory": metrics_factory,
         }
 
-
-
     def get_computational_cost(
         self, image_size: Tuple[int], runs: int = 30, warm_up_runs: int = 5
     ) -> dict:
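
The reformatted block above emits intermediate metrics every evaluation_step processed samples and returns both the final DataFrame and the metrics factory. A sketch of how a caller might wire those hooks; the callback bodies and the commented call are hypothetical, since the real UI updates live in the Streamlit tabs:

def progress_callback(done, total):
    # e.g. drive a progress bar
    print(f"evaluated {done}/{total} samples")


def metrics_callback(metrics_df, done, total):
    # invoked every `evaluation_step` samples with the metrics computed so far
    print(f"intermediate metrics after {done} samples:\n{metrics_df.head()}")


# Hypothetical call, mirroring the names used in eval() above:
# results = model.eval(
#     dataset,
#     progress_callback=progress_callback,
#     metrics_callback=metrics_callback,
# )
# results["metrics_df"]       -> final per-class metrics DataFrame
# results["metrics_factory"]  -> access to precision-recall curves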
