Skip to content

Commit e903fa0

Browse files
committed
Simplified layout. Fixed rerun bugs
1 parent 4ccc93b commit e903fa0

File tree

1 file changed

+106
-61
lines changed

1 file changed

+106
-61
lines changed

app/streamlit_app.py

Lines changed: 106 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
import streamlit as st
88
import pandas as pd
99
import numpy as np
10-
10+
import plotly.express as px
1111

1212
from networkx.readwrite.json_graph import adjacency_data
1313

1414
from sklearn.datasets import fetch_openml, load_digits, load_iris
1515
from sklearn.cluster import AgglomerativeClustering
1616
from sklearn.decomposition import PCA
1717

18-
from tdamapper.core import MapperAlgorithm
18+
from tdamapper.core import MapperAlgorithm, ATTR_SIZE
1919
from tdamapper.cover import CubicalCover, BallCover, TrivialCover
2020
from tdamapper.clustering import TrivialClustering, FailSafeClustering
2121
from tdamapper.plot import MapperLayoutInteractive
@@ -276,12 +276,37 @@ def add_download_graph():
276276
f'📥 Download Graph',
277277
data=get_gzip_bytes(mapper_json),
278278
disabled=mapper_graph is None,
279+
use_container_width=True,
279280
file_name=f'mapper_graph_{int(time.time())}.json.gzip')
281+
282+
283+
def add_graph_caption():
284+
mapper_graph = st.session_state[S_RESULTS].mapper_graph
280285
if mapper_graph is None:
281286
return
287+
import networkx as nx
288+
ccs = nx.connected_components(mapper_graph)
289+
size = nx.get_node_attributes(mapper_graph, ATTR_SIZE)
290+
ff = {}
291+
for cc in ccs:
292+
len_cc = len(cc)
293+
for u in cc:
294+
ff[u] = 1.0 / len_cc
295+
df_ccs = pd.DataFrame({
296+
'kpi': list(ff.values())
297+
})
298+
fig = px.histogram(df_ccs, x='kpi', height=250, nbins=10)
299+
fig.update_layout(
300+
margin=dict(l=0, r=0, t=0, b=0, pad=5),
301+
xaxis_visible=True,
302+
xaxis_title_standoff=0,
303+
xaxis_title='kpi = 1 / connected component size',
304+
yaxis_title_standoff=10,
305+
yaxis_visible=True)
282306
nodes_num = mapper_graph.number_of_nodes()
283307
edges_num = mapper_graph.number_of_edges()
284308
st.caption(f'{nodes_num} nodes, {edges_num} edges')
309+
st.plotly_chart(fig, use_container_width=True)
285310

286311

287312
def add_data_source_csv():
@@ -323,11 +348,13 @@ def add_data_source():
323348

324349

325350
def add_lens_settings():
326-
lens_type = st.selectbox(
327-
'Type',
328-
options=[V_LENS_IDENTITY, V_LENS_PCA],
329-
key=K_LENS_TYPE)
330-
if lens_type == V_LENS_PCA:
351+
opts = st.container(height=200, border=False)
352+
with opts:
353+
lens_type = st.selectbox(
354+
'Type',
355+
options=[V_LENS_IDENTITY, V_LENS_PCA],
356+
key=K_LENS_TYPE)
357+
if lens_type == V_LENS_PCA:
331358
st.number_input(
332359
'Components',
333360
value=1,
@@ -336,65 +363,81 @@ def add_lens_settings():
336363

337364

338365
def add_cover_settings():
339-
cover_type = st.selectbox(
340-
'Type',
341-
options=[V_COVER_BALL, V_COVER_CUBICAL, V_COVER_TRIVIAL],
342-
key=K_COVER_TYPE)
343-
if cover_type == V_COVER_BALL:
344-
st.number_input(
345-
'Radius',
346-
value=100.0,
347-
min_value=0.0,
348-
key=K_COVER_BALL_RADIUS)
349-
st.number_input(
350-
'Lp Metric',
351-
value=2,
352-
min_value=1,
353-
key=K_COVER_BALL_METRIC_P)
354-
elif cover_type == V_COVER_CUBICAL:
355-
st.number_input(
356-
'Intervals',
357-
value=2,
358-
min_value=0,
359-
key=K_COVER_CUBICAL_N)
360-
st.number_input(
361-
'Overlap Fraction',
362-
value=0.10,
363-
min_value=0.0,
364-
max_value=1.0,
365-
key=K_COVER_CUBICAL_OVERLAP)
366+
opts = st.container(height=200, border=False)
367+
with opts:
368+
cover_type = st.selectbox(
369+
'Type',
370+
options=[V_COVER_BALL, V_COVER_CUBICAL, V_COVER_TRIVIAL],
371+
key=K_COVER_TYPE)
372+
if cover_type == V_COVER_BALL:
373+
st.number_input(
374+
'Radius',
375+
value=100.0,
376+
min_value=0.0,
377+
key=K_COVER_BALL_RADIUS)
378+
st.number_input(
379+
'Lp Metric',
380+
value=2,
381+
min_value=1,
382+
key=K_COVER_BALL_METRIC_P)
383+
elif cover_type == V_COVER_CUBICAL:
384+
st.number_input(
385+
'Intervals',
386+
value=2,
387+
min_value=0,
388+
key=K_COVER_CUBICAL_N)
389+
st.number_input(
390+
'Overlap Fraction',
391+
value=0.10,
392+
min_value=0.0,
393+
max_value=1.0,
394+
key=K_COVER_CUBICAL_OVERLAP)
366395

367396

368397
def add_clustering_settings():
369-
clustering_type = st.selectbox(
370-
'Type',
371-
options=[V_CLUSTERING_TRIVIAL, V_CLUSTERING_AGGLOMERATIVE],
372-
key=K_CLUSTERING_TYPE)
373-
if clustering_type == V_CLUSTERING_AGGLOMERATIVE:
374-
st.number_input(
375-
'Clusters',
376-
value=2,
377-
min_value=1,
378-
key=K_CLUSTERING_AGGLOMERATIVE_N)
398+
opts = st.container(height=200, border=False)
399+
with opts:
400+
clustering_type = st.selectbox(
401+
'Type',
402+
options=[V_CLUSTERING_TRIVIAL, V_CLUSTERING_AGGLOMERATIVE],
403+
key=K_CLUSTERING_TYPE)
404+
if clustering_type == V_CLUSTERING_AGGLOMERATIVE:
405+
st.number_input(
406+
'Clusters',
407+
value=2,
408+
min_value=1,
409+
key=K_CLUSTERING_AGGLOMERATIVE_N)
379410

380411

381412
def add_mapper_settings():
382413
df_X = st.session_state[S_RESULTS].df_X
383414
st.markdown('### ⚙️ Settings')
384-
with st.expander('Lens'):
415+
col_0, col_1 = st.columns([2, 4])
416+
with col_0:
417+
tab_0, tab_1, tab_2 = st.tabs([
418+
'🔎 Lens',
419+
'🌐 Cover',
420+
'🧮 Clustering'])
421+
with tab_0:
385422
add_lens_settings()
386-
with st.expander('🌐 Cover'):
423+
with tab_1:
387424
add_cover_settings()
388-
with st.expander('🧮 Clustering'):
425+
with tab_2:
389426
add_clustering_settings()
390-
run = st.button(
391-
'🚀 Run Mapper',
392-
type='primary',
393-
disabled=df_X is None)
427+
col1_0, col1_1, _ = st.columns([2, 2, 2])
428+
with col1_0:
429+
run = st.button(
430+
'🚀 Run Mapper',
431+
type='primary',
432+
use_container_width=True,
433+
disabled=df_X is None)
394434
if run:
395435
with st.spinner('⏳ Computing Mapper...'):
396436
compute_mapper()
397-
add_download_graph()
437+
with col_1:
438+
add_graph_caption()
439+
with col1_1:
440+
add_download_graph()
398441

399442

400443
def get_lens_func():
@@ -577,17 +620,20 @@ def add_graph_plot():
577620

578621
def add_data():
579622
st.markdown('### 📊 Data')
623+
col_0, col_1 = st.columns([2, 4])
624+
with col_0:
625+
add_data_source()
580626
df_X = st.session_state[S_RESULTS].df_X
581627
df_y = st.session_state[S_RESULTS].df_y
582-
add_data_source()
583-
df_X = st.session_state[S_RESULTS].df_X
584-
df_y = st.session_state[S_RESULTS].df_y
585-
if df_X is None:
586-
return
587-
cap = data_caption(df_X, df_y)
588-
with st.expander(cap):
628+
df_all = pd.DataFrame()
629+
cap = 'empty dataset'
630+
if df_X is not None:
589631
df_all = pd.concat([get_sample(df_y, frac=1.0), get_sample(df_X, frac=1.0)], axis=1)
590-
st.dataframe(df_all, height=300)
632+
cap = data_caption(df_X, df_y)
633+
with col_1:
634+
st.markdown('####')
635+
st.dataframe(df_all, height=250, use_container_width=True)
636+
st.caption(cap)
591637

592638

593639
def add_rendering():
@@ -633,7 +679,6 @@ def main():
633679
st.markdown(APP_DESC)
634680
if S_RESULTS not in st.session_state:
635681
st.session_state[S_RESULTS] = Results()
636-
st.markdown('#')
637682
add_data()
638683
st.markdown('#')
639684
add_mapper_settings()

0 commit comments

Comments
 (0)