Skip to content

Commit b548c85

Browse files
committed
Simplified code with button wrapper
1 parent 6d83c64 commit b548c85

File tree

1 file changed

+108
-63
lines changed

1 file changed

+108
-63
lines changed

app/streamlit_app.py

Lines changed: 108 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import time
33
import io
44
import gzip
5-
import random
65

76
import streamlit as st
87
import pandas as pd
@@ -82,6 +81,45 @@
8281
S_RESULTS = 'stored_results'
8382

8483

84+
def spinner_button_trigger(
85+
trigger_key,
86+
running_text,
87+
*args,
88+
**kwargs):
89+
if trigger_key not in st.session_state:
90+
st.session_state[trigger_key] = dict(
91+
trigger=False,
92+
args=None,
93+
kwargs=None)
94+
def _trigger(*args, **kwargs):
95+
st.session_state[trigger_key].update(dict(
96+
trigger=True,
97+
args=args,
98+
kwargs=kwargs))
99+
cont = st.empty()
100+
cb = kwargs.get('on_click', None)
101+
cb_args = kwargs.get('args', None)
102+
cb_kwargs = kwargs.get('kwargs', None)
103+
kw = {}
104+
kw.update(kwargs)
105+
kw['on_click'] = _trigger
106+
kw['args'] = cb_args
107+
kw['kwargs'] = cb_kwargs
108+
if st.session_state[trigger_key]['trigger']:
109+
with cont:
110+
with st.spinner(running_text):
111+
cb(*st.session_state[trigger_key]['args'], **st.session_state[trigger_key]['kwargs'])
112+
st.session_state[trigger_key] = dict(
113+
trigger=False,
114+
args=None,
115+
kwargs=None)
116+
with cont:
117+
butt = st.button(
118+
*args,
119+
**kw)
120+
return butt
121+
122+
85123
class Results:
86124

87125
def __init__(self):
@@ -91,8 +129,8 @@ def __init__(self):
91129
self.df_summary = pd.DataFrame()
92130
self.mapper_graph = nx.Graph()
93131
self.mapper_plot = None
94-
self.fig = go.Figure()
95-
self.update_fig = False
132+
self.mapper_fig = go.Figure()
133+
self.auto_rendering = None
96134

97135
def set_df(self, X, y):
98136
self.df_X = fix_data(X)
@@ -101,8 +139,8 @@ def set_df(self, X, y):
101139
self.df_summary = get_data_summary(self.df_X, self.df_y)
102140
self.mapper_graph = nx.Graph()
103141
self.mapper_plot = None
104-
self.fig = go.Figure()
105-
self.update_fig = False
142+
self.mapper_fig = go.Figure()
143+
self.auto_rendering = None
106144

107145
def set_mapper(self, mapper_graph):
108146
self.mapper_graph = mapper_graph
@@ -113,16 +151,16 @@ def set_mapper(self, mapper_graph):
113151
width=450,
114152
colors=self.X,
115153
seed=VD_SEED)
116-
self.fig = go.Figure()
154+
self.mapper_fig = go.Figure()
117155
nodes_num = mapper_graph.number_of_nodes()
118156
if nodes_num <= MAX_NODES:
119-
self.update_fig = True
157+
self.auto_rendering = True
120158
else:
121-
self.update_fig = False
159+
self.auto_rendering = False
122160

123-
def set_fig(self, fig):
124-
self.fig = fig
125-
self.update_fig = False
161+
def set_mapper_fig(self, mapper_fig):
162+
self.mapper_fig = mapper_fig
163+
self.auto_rendering = None
126164

127165

128166
def lp_metric(p):
@@ -317,6 +355,7 @@ def _load_data(data_source):
317355
st.toast(err, icon='🚨')
318356
df_X, df_y = fix_data(X), fix_data(y)
319357
st.session_state[S_RESULTS].set_df(df_X, df_y)
358+
st.toast('Successfully Loaded Data', icon='✅')
320359

321360

322361
def data_section():
@@ -333,18 +372,15 @@ def data_section():
333372
data_source = st.text_input('Name', placeholder='Name', help=f'Search on [OpenML]({OPENML_URL})')
334373
elif data_source_type == 'CSV':
335374
data_source = st.file_uploader('Upload')
336-
with col_2:
337-
load_cont = st.empty()
338375
with col_1:
339376
st.markdown('####')
340-
with load_cont:
341-
st.button(
377+
with col_2:
378+
spinner_button_trigger(
379+
'load_trigger',
380+
'⏳ Loading Data...',
342381
'📦 Load',
343382
use_container_width=True,
344-
on_click=wrap_callback(
345-
load_cont,
346-
'⏳ Loading Data...',
347-
_load_data),
383+
on_click=_load_data,
348384
args=(data_source,))
349385
df_X = st.session_state[S_RESULTS].df_X
350386
df_y = st.session_state[S_RESULTS].df_y
@@ -355,20 +391,6 @@ def data_section():
355391
st.caption(get_data_caption(df_X, df_y))
356392

357393

358-
def _run(X, lens, cover, clustering):
359-
mapper_algo = MapperAlgorithm(
360-
cover=cover,
361-
clustering=FailSafeClustering(
362-
clustering=clustering,
363-
verbose=False))
364-
mapper_graph = mapper_algo.fit_transform(X, lens)
365-
st.session_state[S_RESULTS].set_mapper(mapper_graph)
366-
st.toast('Succesfully computed!', icon='✅')
367-
nodes_num = mapper_graph.number_of_nodes()
368-
if nodes_num > MAX_NODES:
369-
st.toast('Skipping rendering (graph too large)', icon='⚠️')
370-
371-
372394
def settings_tab(X):
373395
tab_0, tab_1, tab_2 = st.tabs(['🔎 Lens', '🌐 Cover', '🧮 Clustering'])
374396
h = 300
@@ -427,6 +449,20 @@ def settings_tab(X):
427449
return lens, cover, clustering
428450

429451

452+
def _update_mapper(X, lens, cover, clustering):
453+
mapper_algo = MapperAlgorithm(
454+
cover=cover,
455+
clustering=FailSafeClustering(
456+
clustering=clustering,
457+
verbose=False))
458+
mapper_graph = mapper_algo.fit_transform(X, lens)
459+
st.session_state[S_RESULTS].set_mapper(mapper_graph)
460+
st.toast('Successfully Computed Mapper', icon='✅')
461+
auto_rendering = st.session_state[S_RESULTS].auto_rendering
462+
if auto_rendering is False:
463+
st.toast('Automatic Rendering Disabled: Graph Too Large', icon='⚠️')
464+
465+
430466
def settings_section():
431467
st.subheader('⚙️ Mapper Settings')
432468
X = st.session_state[S_RESULTS].X
@@ -435,19 +471,17 @@ def settings_section():
435471
with col_0:
436472
lens, cover, clustering = settings_tab(X)
437473
with col_2:
438-
run_cont = st.empty()
439-
440-
with run_cont:
441-
st.button(
474+
spinner_button_trigger(
475+
'run_trigger',
476+
'⏳ Computing...',
442477
'🚀 Run Mapper',
443478
use_container_width=True,
444479
disabled=X.size == 0,
445-
on_click=wrap_callback(
446-
run_cont,
447-
'⏳ Computing...',
448-
_run),
480+
on_click=_update_mapper,
449481
args=(X, lens, cover, clustering,))
482+
450483
mapper_graph = st.session_state[S_RESULTS].mapper_graph
484+
451485
with col_1:
452486
with st.container(border=True):
453487
fig_hist = graph_histogram(mapper_graph)
@@ -460,12 +494,17 @@ def settings_section():
460494
graph_download_button(mapper_graph)
461495

462496

463-
def _update(mapper_plot, seed, colors):
464-
mapper_plot.update(colors=colors, seed=seed)
497+
def _update_fig(seed, colors):
498+
mapper_plot = st.session_state[S_RESULTS].mapper_plot
499+
if mapper_plot is None:
500+
return
501+
mapper_plot.update(
502+
colors=colors,
503+
seed=seed)
465504
mapper_fig = mapper_plot.plot()
466505
mapper_fig.update_layout(uirevision='constant')
467-
st.session_state[S_RESULTS].set_fig(mapper_fig)
468-
st.toast('Succesfully rendered!', icon='✅')
506+
st.session_state[S_RESULTS].set_mapper_fig(mapper_fig)
507+
st.toast('Successfully Rendered Graph', icon='✅')
469508

470509

471510
def rendering_section():
@@ -475,9 +514,13 @@ def rendering_section():
475514
df_y = st.session_state[S_RESULTS].df_y
476515
X = st.session_state[S_RESULTS].X
477516
mapper_plot = st.session_state[S_RESULTS].mapper_plot
478-
col_4, _ = st.columns([2, 4])
479-
with col_4:
480-
popover = st.popover('🎨 Options', use_container_width=True)
517+
col_0, col_1 = st.columns([2, 4])
518+
with col_1:
519+
popover = st.popover(
520+
'🎨 Options',
521+
use_container_width=True,
522+
disabled=mapper_plot is None)
523+
481524
with popover:
482525
seed = st.number_input('Seed', value=VD_SEED)
483526
data_edit = st.data_editor(
@@ -504,31 +547,33 @@ def rendering_section():
504547
selected = pd.concat([df_Xy[c] for c in color_features], axis=1)
505548
if not selected.empty:
506549
colors = selected.to_numpy()
507-
update_cont = st.empty()
508-
with update_cont:
509-
st.button(
550+
551+
auto_rendering = st.session_state[S_RESULTS].auto_rendering
552+
if auto_rendering:
553+
_update_fig(seed, colors)
554+
555+
with col_0:
556+
spinner_button_trigger(
557+
'update_trigger',
558+
'⏳ Rendering...',
510559
'🌊 Update',
511560
use_container_width=True,
512561
disabled=mapper_plot is None,
513-
on_click=wrap_callback(
514-
update_cont,
515-
'⏳ Rendering...',
516-
_update),
517-
args=(mapper_plot, seed, colors))
518-
if st.session_state[S_RESULTS].update_fig:
519-
wrap_callback(
520-
update_cont,
521-
'⏳ Rendering...',
522-
_update)(mapper_plot, seed, colors)
523-
st.session_state[S_RESULTS].update_fig = False
524-
mapper_fig = st.session_state[S_RESULTS].fig
562+
on_click=_update_fig,
563+
args=(seed, colors))
564+
565+
mapper_fig = st.session_state[S_RESULTS].mapper_fig
525566
with st.container(border=True):
526567
st.plotly_chart(
527568
mapper_fig,
528569
height=450,
529570
use_container_width=True)
530571

531572

573+
574+
575+
576+
532577
def main():
533578
set_headings()
534579
if S_RESULTS not in st.session_state:

0 commit comments

Comments
 (0)