Commit 62573f8

Added live UI updates
1 parent 5790373 commit 62573f8

3 files changed: +128 additions, -42 deletions

app.py

Lines changed: 12 additions & 0 deletions
@@ -70,6 +70,7 @@ def browse_folder():
     st.session_state.setdefault("max_detections", 100)
     st.session_state.setdefault("device", "cpu")
     st.session_state.setdefault("batch_size", 1)
+    st.session_state.setdefault("evaluation_step", 5)
     st.session_state.setdefault("detection_model", None)
     st.session_state.setdefault("detection_model_loaded", False)

@@ -182,6 +183,15 @@ def browse_folder():
         step=1,
         key="batch_size",
     )
+    st.number_input(
+        "Evaluation Step",
+        min_value=0,
+        max_value=1000,
+        value=st.session_state.get("evaluation_step", 10),
+        step=1,
+        key="evaluation_step",
+        help="Update UI with intermediate metrics every N images (0 = disable intermediate updates)"
+    )

     # Load model action in sidebar
     from detectionmetrics.models.torch_detection import TorchImageDetectionModel
@@ -220,12 +230,14 @@ def browse_folder():
     max_detections = int(st.session_state.get('max_detections', 100))
     device = st.session_state.get('device', 'cpu')
     batch_size = int(st.session_state.get('batch_size', 1))
+    evaluation_step = int(st.session_state.get('evaluation_step', 5))
     config_data = {
         "confidence_threshold": confidence_threshold,
         "nms_threshold": nms_threshold,
         "max_detections_per_image": max_detections,
         "device": device,
         "batch_size": batch_size,
+        "evaluation_step": evaluation_step,
     }
     with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp_cfg:
         json.dump(config_data, tmp_cfg)

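For reference, the sidebar values above are serialized into a temporary JSON model config that the evaluation code later reads. Below is a minimal sketch of what that file ends up containing; only the field names come from the diff above, the concrete values are illustrative:

```python
import json
import tempfile

# Illustrative values; in app.py these come from st.session_state.
config_data = {
    "confidence_threshold": 0.5,
    "nms_threshold": 0.5,
    "max_detections_per_image": 100,
    "device": "cpu",
    "batch_size": 1,
    "evaluation_step": 5,  # emit intermediate metrics every 5 images; 0 disables
}

with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as tmp_cfg:
    json.dump(config_data, tmp_cfg)
    print("model config written to", tmp_cfg.name)
```
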
detectionmetrics/models/torch_detection.py

Lines changed: 25 additions & 2 deletions
@@ -315,6 +315,7 @@ def eval(
         predictions_outdir: Optional[str] = None,
         results_per_sample: bool = False,
         progress_callback=None,
+        metrics_callback=None,
     ) -> pd.DataFrame:
         """Evaluate model over a detection dataset and compute metrics

@@ -330,6 +331,8 @@ def eval(
         :type results_per_sample: bool
         :param progress_callback: Optional callback function for progress updates in Streamlit UI
         :type progress_callback: Optional[Callable[[int, int], None]]
+        :param metrics_callback: Optional callback function for intermediate metrics updates in Streamlit UI
+        :type metrics_callback: Optional[Callable[[pd.DataFrame, int, int], None]]
         :return: DataFrame containing evaluation results
         :rtype: pd.DataFrame
         """
@@ -353,16 +356,28 @@ def eval(
             splits=[split] if isinstance(split, str) else split,
         )

+        # This ensures compatibility with Streamlit and callback functions
+        if progress_callback is not None and metrics_callback is not None:
+            num_workers = 0
+        else:
+            num_workers = self.model_cfg.get("num_workers")
+
         dataloader = DataLoader(
             dataset,
             batch_size=self.model_cfg.get("batch_size", 1),
-            num_workers=self.model_cfg.get("num_workers", 1),
-            collate_fn=lambda x: tuple(zip(*x)),  # handles variable-size targets
+            num_workers=num_workers,
+            collate_fn=lambda batch: tuple(zip(*batch)),  # handles variable-size targets
         )

         # Get iou_threshold from model config, default to 0.5 if not present
         iou_threshold = self.model_cfg.get("iou_threshold", 0.5)

+        # Get evaluation_step from model config, default to None (no intermediate updates)
+        evaluation_step = self.model_cfg.get("evaluation_step", None)
+        # If evaluation_step is 0, treat as None (disabled)
+        if evaluation_step == 0:
+            evaluation_step = None
+
         # Init metrics
         metrics_factory = um.DetectionMetricsFactory(
             iou_threshold=iou_threshold, num_classes=self.n_classes
@@ -472,6 +487,14 @@ def eval(
            # Call progress callback if provided
            if progress_callback is not None:
                progress_callback(processed_samples, total_samples)
+
+           # Call metrics callback if provided and evaluation_step is reached
+           if (metrics_callback is not None and
+                   evaluation_step is not None and
+                   processed_samples % evaluation_step == 0):
+               # Get intermediate metrics
+               intermediate_metrics = metrics_factory.get_metrics_dataframe(self.ontology)
+               metrics_callback(intermediate_metrics, processed_samples, total_samples)

         # Return both the DataFrame and the metrics factory for access to precision-recall curves
         return {

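Outside Streamlit, the two callbacks accepted by eval() are plain callables with the signatures documented in the docstring above: (processed, total) for progress and (metrics_df, processed, total) for intermediate metrics. A minimal sketch of such callbacks, exercised here against a fabricated metrics table rather than a real model.eval(...) run:

```python
import pandas as pd

def progress_callback(processed: int, total: int) -> None:
    # Matches Optional[Callable[[int, int], None]] from the eval() docstring.
    print(f"{processed}/{total} images evaluated")

def metrics_callback(metrics_df: pd.DataFrame, processed: int, total: int) -> None:
    # Matches Optional[Callable[[pd.DataFrame, int, int], None]]; the table has
    # one column per class plus a 'mean' summary column, with metric names as rows.
    if "mean" in metrics_df.columns:
        print(f"after {processed}/{total} images, mAP so far: {metrics_df['mean']['AP']:.3f}")

# Fabricated stand-in for metrics_factory.get_metrics_dataframe(...); in a real run
# these functions are passed as eval(..., progress_callback=..., metrics_callback=...).
fake_metrics = pd.DataFrame(
    {"car": [0.61, 0.70, 0.55], "mean": [0.61, 0.70, 0.55]},
    index=["AP", "Precision", "Recall"],
)
progress_callback(5, 100)
metrics_callback(fake_metrics, 5, 100)
```

Note that when both callbacks are supplied, the diff forces num_workers = 0, presumably so the DataLoader stays in the main process and the callbacks can safely touch the Streamlit UI.
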
tabs/evaluator.py

Lines changed: 91 additions & 40 deletions
@@ -113,6 +113,10 @@ def evaluator_tab():
        progress_bar = st.progress(0)
        status_text = st.empty()

+       # Create placeholders for intermediate metrics that will be updated in place
+       intermediate_metrics_placeholder = st.empty()
+       intermediate_table_placeholder = st.empty()
+
        def progress_callback(processed, total):
            """Progress callback for Streamlit UI"""
            try:
@@ -122,6 +126,42 @@ def progress_callback(processed, total):
            except Exception as e:
                st.error(f"Progress callback error: {e}")

+       def metrics_callback(metrics_df, processed, total):
+           """Metrics callback for intermediate results display"""
+           try:
+               # Update the metrics placeholder with current summary metrics
+               if 'mean' in metrics_df.columns:
+                   mean_metrics = metrics_df['mean']
+
+                   with intermediate_metrics_placeholder.container():
+                       st.markdown(f"#### 📊 Intermediate Results (after {processed} images)")
+
+                       col1, col2, col3 = st.columns(3)
+                       with col1:
+                           st.metric("mAP", f"{mean_metrics.get('AP', 0):.3f}")
+                       with col2:
+                           st.metric("Mean Precision", f"{mean_metrics.get('Precision', 0):.3f}")
+                       with col3:
+                           st.metric("Mean Recall", f"{mean_metrics.get('Recall', 0):.3f}")
+
+               # Update the table placeholder with current per-class results
+               per_class_results = metrics_df.drop(columns=['mean']) if 'mean' in metrics_df.columns else metrics_df
+               per_class_results = per_class_results.drop(['AUC-PR', 'mAP@[0.5:0.95]'], errors='ignore')
+
+               # Round for display
+               display_df = per_class_results.copy()
+               numeric_columns = display_df.select_dtypes(include=['float64', 'int64']).columns
+               for col in numeric_columns:
+                   if col in display_df.columns:
+                       display_df[col] = display_df[col].round(3)
+
+               with intermediate_table_placeholder.container():
+                   st.markdown("#### Per-Class Metrics (Intermediate)")
+                   st.dataframe(display_df, use_container_width=True)
+
+           except Exception as e:
+               st.error(f"Metrics callback error: {e}")
+
        # Run evaluation with progress tracking
        # Use full dataset for evaluation

@@ -131,7 +171,7 @@ def progress_callback(processed, total):
            # Create a shallow copy of the dataset object with only first 10 rows
            import copy
            dataset_subset = copy.copy(dataset)
-           dataset_subset.dataset = dataset.dataset.iloc[:10].copy()
+           dataset_subset.dataset = dataset.dataset.iloc[:100].copy()
        else:
            st.warning("Dataset object does not have a 'dataset' attribute; using as is.")
            dataset_subset = dataset
@@ -142,17 +182,20 @@ def progress_callback(processed, total):
                ontology_translation=ontology_translation_path,
                predictions_outdir=predictions_outdir,
                results_per_sample=save_predictions,
-               progress_callback=progress_callback
+               progress_callback=progress_callback,
+               metrics_callback=metrics_callback
            )
        except Exception as e:
            st.error(f"Error in model.eval(): {e}")
            return

        # Results ready

-       # Clear progress elements
+       # Clear progress elements and intermediate results
        progress_bar.empty()
        status_text.empty()
+       intermediate_metrics_placeholder.empty()
+       intermediate_table_placeholder.empty()

        # Store results in session state
        st.session_state['evaluation_results'] = results
@@ -201,33 +244,41 @@ def display_evaluation_results(results):
    if 'mean' in metrics_df.columns:
        mean_metrics = metrics_df['mean']

-       col1, col2, col3, col4 = st.columns(4)
+       col1, col2, col3, col4, col5 = st.columns(5)
        with col1:
            st.metric("mAP", f"{mean_metrics.get('AP', 0):.3f}")
        with col2:
            st.metric("Mean Precision", f"{mean_metrics.get('Precision', 0):.3f}")
        with col3:
            st.metric("Mean Recall", f"{mean_metrics.get('Recall', 0):.3f}")
        with col4:
-           total_detections = mean_metrics.get('TP', 0) + mean_metrics.get('FP', 0)
-           st.metric("Total Detections", f"{total_detections:.0f}")
-
-       # Add COCO mAP and AUC-PR in a second row
-       col5, col6, col7, col8 = st.columns(4)
-       with col5:
            coco_map = mean_metrics.get('mAP@[0.5:0.95]', 0)
            st.metric("mAP@[0.5:0.95]", f"{coco_map:.3f}")
-       with col6:
+       with col5:
            auc_pr = mean_metrics.get('AUC-PR', 0)
            st.metric("AUC-PR", f"{auc_pr:.3f}")
-       with col7:
-           # Empty column for spacing
-           st.empty()
-       with col8:
-           # Empty column for spacing
-           st.empty()

-   # Display Precision-Recall Curve
+   # Display per-class metrics first
+   st.markdown("#### Per-Class Metrics")
+
+   # Filter out the 'mean' column for per-class display
+   per_class_results = metrics_df.drop(columns=['mean']) if 'mean' in metrics_df.columns else metrics_df
+
+   # Remove overall metrics rows (AUC-PR and mAP@[0.5:0.95]) from per-class display
+   per_class_results = per_class_results.drop(['AUC-PR', 'mAP@[0.5:0.95]'], errors='ignore')
+
+   # Create a more readable display
+   display_df = per_class_results.copy()
+
+   # Round numeric columns for better display
+   numeric_columns = display_df.select_dtypes(include=['float64', 'int64']).columns
+   for col in numeric_columns:
+       if col in display_df.columns:
+           display_df[col] = display_df[col].round(3)
+
+   st.dataframe(display_df, use_container_width=True)
+
+   # Now display Precision-Recall Curve
    if metrics_factory is not None:
        st.markdown("#### Precision-Recall Curve")

@@ -287,38 +338,38 @@ def display_evaluation_results(results):
            st.error(f"Error plotting precision-recall curve: {e}")
            st.info("Precision-recall curve data not available.")

-   # Display per-class metrics
-   st.markdown("#### Per-Class Metrics")
-
-   # Filter out the 'mean' column for per-class display
-   per_class_results = metrics_df.drop(columns=['mean']) if 'mean' in metrics_df.columns else metrics_df
-
-   # Remove overall metrics rows (AUC-PR and mAP@[0.5:0.95]) from per-class display
-   per_class_results = per_class_results.drop(['AUC-PR', 'mAP@[0.5:0.95]'], errors='ignore')
-
-   # Create a more readable display
-   display_df = per_class_results.copy()
-
-   # Round numeric columns for better display
-   numeric_columns = display_df.select_dtypes(include=['float64', 'int64']).columns
-   for col in numeric_columns:
-       if col in display_df.columns:
-           display_df[col] = display_df[col].round(3)
-
-   st.dataframe(display_df, use_container_width=True)
-
    # Download results
    st.markdown("#### Download Results")

    # Convert to CSV for download
    csv = metrics_df.to_csv(index=True)
    st.download_button(
-       label="📥 Download Results as CSV",
+       label="📥 Download per class metrics",
        data=csv,
        file_name="evaluation_results.csv",
        mime="text/csv"
    )
-
+   try:
+       curve_data = metrics_factory.get_overall_precision_recall_curve() if metrics_factory is not None else None
+       if curve_data is not None:
+           import io
+           import pandas as pd
+           pr_points_df = pd.DataFrame({
+               "recall": curve_data["recall"],
+               "precision": curve_data["precision"]
+           })
+           pr_csv = pr_points_df.to_csv(index=False)
+           st.download_button(
+               label="📈 Download precision-recall points",
+               data=pr_csv,
+               file_name="precision_recall_points.csv",
+               mime="text/csv"
+           )
+       else:
+           st.write("No precision-recall data available.")
+   except Exception as e:
+       st.write(f"Error preparing precision-recall points: {e}")
+
    # Show detailed statistics
    with st.expander("📊 Detailed Statistics"):
        st.markdown("**Results Shape:**")

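The live-update behaviour in evaluator.py relies on Streamlit placeholders created with st.empty(): each metrics_callback invocation re-renders the same two slots instead of appending new widgets, and the slots are cleared once the run finishes. A standalone sketch of that pattern, with the loop, the step of 5, and all metric values made up for illustration:

```python
import time

import pandas as pd
import streamlit as st

st.title("Live metrics demo")

progress_bar = st.progress(0)
status_text = st.empty()
metrics_placeholder = st.empty()  # summary numbers, re-rendered in place
table_placeholder = st.empty()    # per-class table, re-rendered in place

TOTAL, STEP = 20, 5  # stand-ins for total_samples and evaluation_step
for processed in range(1, TOTAL + 1):
    progress_bar.progress(processed / TOTAL)
    status_text.text(f"Processed {processed}/{TOTAL} images")

    if processed % STEP == 0:
        with metrics_placeholder.container():
            st.markdown(f"#### Intermediate Results (after {processed} images)")
            st.metric("mAP", f"{0.40 + processed / 100:.3f}")  # fabricated value
        with table_placeholder.container():
            st.dataframe(pd.DataFrame({"class": ["car", "person"], "AP": [0.61, 0.47]}))
    time.sleep(0.2)

# Once the run is done, clear the live widgets, as evaluator.py does.
progress_bar.empty()
status_text.empty()
metrics_placeholder.empty()
table_placeholder.empty()
```
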