Skip to content

Commit 5790373

Browse files
Reframed the UI with input sidebar
1 parent cc82119 commit 5790373

File tree

4 files changed

+541
-557
lines changed

4 files changed

+541
-557
lines changed

app.py

Lines changed: 273 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,288 @@
11
import streamlit as st
2+
import os
3+
import sys
4+
import subprocess
25
from tabs.dataset_viewer import dataset_viewer_tab
36
from tabs.inference import inference_tab
47
from tabs.evaluator import evaluator_tab
58

9+
def browse_folder():
10+
"""
11+
Opens a native folder selection dialog and returns the selected folder path.
12+
Works on Windows, macOS, and Linux (with zenity or kdialog).
13+
Returns None if cancelled or error.
14+
"""
15+
try:
16+
if sys.platform.startswith("win"):
17+
script = (
18+
'Add-Type -AssemblyName System.windows.forms;'
19+
'$f=New-Object System.Windows.Forms.FolderBrowserDialog;'
20+
'if($f.ShowDialog() -eq "OK"){Write-Output $f.SelectedPath}'
21+
)
22+
result = subprocess.run(
23+
["powershell", "-NoProfile", "-Command", script],
24+
capture_output=True, text=True, timeout=30
25+
)
26+
folder = result.stdout.strip()
27+
return folder if folder else None
28+
elif sys.platform == "darwin":
29+
script = 'POSIX path of (choose folder with prompt "Select dataset folder:")'
30+
result = subprocess.run(
31+
["osascript", "-e", script],
32+
capture_output=True, text=True, timeout=30
33+
)
34+
folder = result.stdout.strip()
35+
return folder if folder else None
36+
else:
37+
# Linux: try zenity, then kdialog
38+
for cmd in [
39+
["zenity", "--file-selection", "--directory", "--title=Select dataset folder"],
40+
["kdialog", "--getexistingdirectory", "--title", "Select dataset folder"]
41+
]:
42+
try:
43+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
44+
folder = result.stdout.strip()
45+
if folder:
46+
return folder
47+
except Exception:
48+
continue
49+
return None
50+
except Exception:
51+
return None
52+
653
st.set_page_config(page_title="DetectionMetrics", layout="wide")
754

855
# st.title("DetectionMetrics")
956

1057
PAGES = {
1158
"Dataset Viewer": dataset_viewer_tab,
1259
"Inference": inference_tab,
13-
"Evaluator": evaluator_tab
60+
"Evaluator": evaluator_tab,
1461
}
1562

16-
page = st.sidebar.radio("DetectionMetrics", list(PAGES.keys()))
63+
# Initialize commonly used session state keys
64+
st.session_state.setdefault("dataset_path", "")
65+
st.session_state.setdefault("dataset_type_selectbox", "Coco")
66+
st.session_state.setdefault("split_selectbox", "val")
67+
st.session_state.setdefault("config_option", "Manual Configuration")
68+
st.session_state.setdefault("confidence_threshold", 0.5)
69+
st.session_state.setdefault("nms_threshold", 0.5)
70+
st.session_state.setdefault("max_detections", 100)
71+
st.session_state.setdefault("device", "cpu")
72+
st.session_state.setdefault("batch_size", 1)
73+
st.session_state.setdefault("detection_model", None)
74+
st.session_state.setdefault("detection_model_loaded", False)
75+
76+
# Sidebar: Dataset Inputs
77+
with st.sidebar:
78+
with st.expander("Dataset Inputs", expanded=True):
79+
# First row: Type and Split
80+
col1, col2 = st.columns(2)
81+
with col1:
82+
st.selectbox(
83+
"Type",
84+
["Coco", "Custom"],
85+
key="dataset_type_selectbox",
86+
)
87+
with col2:
88+
st.selectbox(
89+
"Split",
90+
["train", "val"],
91+
key="split_selectbox",
92+
)
93+
94+
# Second row: Path and Browse button
95+
col1, col2 = st.columns([3, 1])
96+
with col1:
97+
dataset_path_input = st.text_input(
98+
"Dataset Folder Path",
99+
value=st.session_state.get("dataset_path", ""),
100+
key="dataset_path_input",
101+
)
102+
with col2:
103+
st.markdown("<div style='margin-bottom: 1.75rem;'></div>", unsafe_allow_html=True)
104+
if st.button("Browse", key="browse_button"):
105+
folder = browse_folder()
106+
if folder and os.path.isdir(folder):
107+
st.session_state["dataset_path"] = folder
108+
st.rerun()
109+
elif folder is not None:
110+
st.warning("Selected path is not a valid folder.")
111+
112+
if dataset_path_input != st.session_state.get("dataset_path", ""):
113+
st.session_state["dataset_path"] = dataset_path_input
114+
115+
with st.expander("Model Inputs", expanded=False):
116+
st.file_uploader(
117+
"Model File (.pt, .onnx, .h5, .pb, .pth)",
118+
type=["pt", "onnx", "h5", "pb", "pth"],
119+
key="model_file",
120+
help="Upload your trained model file.",
121+
)
122+
st.file_uploader(
123+
"Ontology File (.json)",
124+
type=["json"],
125+
key="ontology_file",
126+
help="Upload a JSON file with class labels.",
127+
)
128+
st.radio(
129+
"Configuration Method:",
130+
["Manual Configuration", "Upload Config File"],
131+
key="config_option",
132+
horizontal=True,
133+
)
134+
if st.session_state.get("config_option", "Manual Configuration") == "Upload Config File":
135+
st.file_uploader(
136+
"Configuration File (.json)",
137+
type=["json"],
138+
key="config_file",
139+
help="Upload a JSON configuration file.",
140+
)
141+
else:
142+
col1, col2 = st.columns(2)
143+
with col1:
144+
st.slider(
145+
"Confidence Threshold",
146+
min_value=0.0,
147+
max_value=1.0,
148+
value=st.session_state.get("confidence_threshold", 0.5),
149+
step=0.01,
150+
key="confidence_threshold",
151+
help="Minimum confidence score for detections",
152+
)
153+
st.slider(
154+
"NMS Threshold",
155+
min_value=0.0,
156+
max_value=1.0,
157+
value=st.session_state.get("nms_threshold", 0.5),
158+
step=0.01,
159+
key="nms_threshold",
160+
help="Non-maximum suppression threshold",
161+
)
162+
st.number_input(
163+
"Max Detections/Image",
164+
min_value=1,
165+
max_value=1000,
166+
value=st.session_state.get("max_detections", 100),
167+
step=1,
168+
key="max_detections",
169+
)
170+
with col2:
171+
st.selectbox(
172+
"Device",
173+
["cpu", "gpu"],
174+
index=0 if st.session_state.get("device", "cpu") == "cpu" else 1,
175+
key="device",
176+
)
177+
st.number_input(
178+
"Batch Size",
179+
min_value=1,
180+
max_value=256,
181+
value=st.session_state.get("batch_size", 1),
182+
step=1,
183+
key="batch_size",
184+
)
185+
186+
# Load model action in sidebar
187+
from detectionmetrics.models.torch_detection import TorchImageDetectionModel
188+
import json, tempfile
189+
190+
191+
load_model_btn = st.button(
192+
"Load Model",
193+
type="primary",
194+
use_container_width=True,
195+
help="Load and save the model for use in the Inference tab",
196+
key="sidebar_load_model_btn",
197+
)
198+
199+
if load_model_btn:
200+
model_file = st.session_state.get("model_file")
201+
ontology_file = st.session_state.get("ontology_file")
202+
config_option = st.session_state.get("config_option", "Manual Configuration")
203+
config_file = st.session_state.get("config_file") if config_option == "Upload Config File" else None
204+
205+
# Prepare configuration
206+
config_data = None
207+
config_path = None
208+
try:
209+
if config_option == "Upload Config File":
210+
if config_file is not None:
211+
config_data = json.load(config_file)
212+
with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp_cfg:
213+
json.dump(config_data, tmp_cfg)
214+
config_path = tmp_cfg.name
215+
else:
216+
st.error("Please upload a configuration file")
217+
else:
218+
confidence_threshold = float(st.session_state.get('confidence_threshold', 0.5))
219+
nms_threshold = float(st.session_state.get('nms_threshold', 0.5))
220+
max_detections = int(st.session_state.get('max_detections', 100))
221+
device = st.session_state.get('device', 'cpu')
222+
batch_size = int(st.session_state.get('batch_size', 1))
223+
config_data = {
224+
"confidence_threshold": confidence_threshold,
225+
"nms_threshold": nms_threshold,
226+
"max_detections_per_image": max_detections,
227+
"device": device,
228+
"batch_size": batch_size,
229+
}
230+
with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp_cfg:
231+
json.dump(config_data, tmp_cfg)
232+
config_path = tmp_cfg.name
233+
except Exception as e:
234+
st.error(f"Failed to prepare configuration: {e}")
235+
config_path = None
236+
237+
if model_file is None:
238+
st.error("Please upload a model file")
239+
elif config_path is None:
240+
st.error("Please provide a valid model configuration")
241+
elif ontology_file is None:
242+
st.error("Please upload an ontology file")
243+
else:
244+
with st.spinner("Loading model..."):
245+
# Persist ontology to temp file
246+
try:
247+
ontology_data = json.load(ontology_file)
248+
with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp_ont:
249+
json.dump(ontology_data, tmp_ont)
250+
ontology_path = tmp_ont.name
251+
except Exception as e:
252+
st.error(f"Failed to load ontology: {e}")
253+
ontology_path = None
254+
255+
# Persist model to temp file
256+
try:
257+
with tempfile.NamedTemporaryFile(delete=False, suffix='.pt', mode='wb') as tmp_model:
258+
tmp_model.write(model_file.read())
259+
model_temp_path = tmp_model.name
260+
except Exception as e:
261+
st.error(f"Failed to save model file: {e}")
262+
model_temp_path = None
263+
264+
if ontology_path and model_temp_path:
265+
try:
266+
model = TorchImageDetectionModel(
267+
model=model_temp_path,
268+
model_cfg=config_path,
269+
ontology_fname=ontology_path,
270+
device=st.session_state.get('device', 'cpu'),
271+
)
272+
st.session_state.detection_model = model
273+
st.session_state.detection_model_loaded = True
274+
st.success("Model loaded and saved for inference")
275+
except Exception as e:
276+
st.session_state.detection_model = None
277+
st.session_state.detection_model_loaded = False
278+
st.error(f"Failed to load model: {e}")
279+
280+
# Main content area with horizontal tabs
281+
tab1, tab2, tab3 = st.tabs(["Dataset Viewer", "Inference", "Evaluator"])
17282

18-
PAGES[page]()
283+
with tab1:
284+
dataset_viewer_tab()
285+
with tab2:
286+
inference_tab()
287+
with tab3:
288+
evaluator_tab()

0 commit comments

Comments
 (0)