77import streamlit as st
88import pandas as pd
99import numpy as np
10-
10+ import plotly . express as px
1111
1212from networkx .readwrite .json_graph import adjacency_data
1313
1414from sklearn .datasets import fetch_openml , load_digits , load_iris
1515from sklearn .cluster import AgglomerativeClustering
1616from sklearn .decomposition import PCA
1717
18- from tdamapper .core import MapperAlgorithm
18+ from tdamapper .core import MapperAlgorithm , ATTR_SIZE
1919from tdamapper .cover import CubicalCover , BallCover , TrivialCover
2020from tdamapper .clustering import TrivialClustering , FailSafeClustering
2121from tdamapper .plot import MapperLayoutInteractive
@@ -276,12 +276,37 @@ def add_download_graph():
276276 f'📥 Download Graph' ,
277277 data = get_gzip_bytes (mapper_json ),
278278 disabled = mapper_graph is None ,
279+ use_container_width = True ,
279280 file_name = f'mapper_graph_{ int (time .time ())} .json.gzip' )
281+
282+
283+ def add_graph_caption ():
284+ mapper_graph = st .session_state [S_RESULTS ].mapper_graph
280285 if mapper_graph is None :
281286 return
287+ import networkx as nx
288+ ccs = nx .connected_components (mapper_graph )
289+ size = nx .get_node_attributes (mapper_graph , ATTR_SIZE )
290+ ff = {}
291+ for cc in ccs :
292+ len_cc = len (cc )
293+ for u in cc :
294+ ff [u ] = 1.0 / len_cc
295+ df_ccs = pd .DataFrame ({
296+ 'kpi' : list (ff .values ())
297+ })
298+ fig = px .histogram (df_ccs , x = 'kpi' , height = 250 , nbins = 10 )
299+ fig .update_layout (
300+ margin = dict (l = 0 , r = 0 , t = 0 , b = 0 , pad = 5 ),
301+ xaxis_visible = True ,
302+ xaxis_title_standoff = 0 ,
303+ xaxis_title = 'kpi = 1 / connected component size' ,
304+ yaxis_title_standoff = 10 ,
305+ yaxis_visible = True )
282306 nodes_num = mapper_graph .number_of_nodes ()
283307 edges_num = mapper_graph .number_of_edges ()
284308 st .caption (f'{ nodes_num } nodes, { edges_num } edges' )
309+ st .plotly_chart (fig , use_container_width = True )
285310
286311
287312def add_data_source_csv ():
@@ -323,11 +348,13 @@ def add_data_source():
323348
324349
325350def add_lens_settings ():
326- lens_type = st .selectbox (
327- 'Type' ,
328- options = [V_LENS_IDENTITY , V_LENS_PCA ],
329- key = K_LENS_TYPE )
330- if lens_type == V_LENS_PCA :
351+ opts = st .container (height = 200 , border = False )
352+ with opts :
353+ lens_type = st .selectbox (
354+ 'Type' ,
355+ options = [V_LENS_IDENTITY , V_LENS_PCA ],
356+ key = K_LENS_TYPE )
357+ if lens_type == V_LENS_PCA :
331358 st .number_input (
332359 'Components' ,
333360 value = 1 ,
@@ -336,65 +363,81 @@ def add_lens_settings():
336363
337364
338365def add_cover_settings ():
339- cover_type = st .selectbox (
340- 'Type' ,
341- options = [V_COVER_BALL , V_COVER_CUBICAL , V_COVER_TRIVIAL ],
342- key = K_COVER_TYPE )
343- if cover_type == V_COVER_BALL :
344- st .number_input (
345- 'Radius' ,
346- value = 100.0 ,
347- min_value = 0.0 ,
348- key = K_COVER_BALL_RADIUS )
349- st .number_input (
350- 'Lp Metric' ,
351- value = 2 ,
352- min_value = 1 ,
353- key = K_COVER_BALL_METRIC_P )
354- elif cover_type == V_COVER_CUBICAL :
355- st .number_input (
356- 'Intervals' ,
357- value = 2 ,
358- min_value = 0 ,
359- key = K_COVER_CUBICAL_N )
360- st .number_input (
361- 'Overlap Fraction' ,
362- value = 0.10 ,
363- min_value = 0.0 ,
364- max_value = 1.0 ,
365- key = K_COVER_CUBICAL_OVERLAP )
366+ opts = st .container (height = 200 , border = False )
367+ with opts :
368+ cover_type = st .selectbox (
369+ 'Type' ,
370+ options = [V_COVER_BALL , V_COVER_CUBICAL , V_COVER_TRIVIAL ],
371+ key = K_COVER_TYPE )
372+ if cover_type == V_COVER_BALL :
373+ st .number_input (
374+ 'Radius' ,
375+ value = 100.0 ,
376+ min_value = 0.0 ,
377+ key = K_COVER_BALL_RADIUS )
378+ st .number_input (
379+ 'Lp Metric' ,
380+ value = 2 ,
381+ min_value = 1 ,
382+ key = K_COVER_BALL_METRIC_P )
383+ elif cover_type == V_COVER_CUBICAL :
384+ st .number_input (
385+ 'Intervals' ,
386+ value = 2 ,
387+ min_value = 0 ,
388+ key = K_COVER_CUBICAL_N )
389+ st .number_input (
390+ 'Overlap Fraction' ,
391+ value = 0.10 ,
392+ min_value = 0.0 ,
393+ max_value = 1.0 ,
394+ key = K_COVER_CUBICAL_OVERLAP )
366395
367396
368397def add_clustering_settings ():
369- clustering_type = st .selectbox (
370- 'Type' ,
371- options = [V_CLUSTERING_TRIVIAL , V_CLUSTERING_AGGLOMERATIVE ],
372- key = K_CLUSTERING_TYPE )
373- if clustering_type == V_CLUSTERING_AGGLOMERATIVE :
374- st .number_input (
375- 'Clusters' ,
376- value = 2 ,
377- min_value = 1 ,
378- key = K_CLUSTERING_AGGLOMERATIVE_N )
398+ opts = st .container (height = 200 , border = False )
399+ with opts :
400+ clustering_type = st .selectbox (
401+ 'Type' ,
402+ options = [V_CLUSTERING_TRIVIAL , V_CLUSTERING_AGGLOMERATIVE ],
403+ key = K_CLUSTERING_TYPE )
404+ if clustering_type == V_CLUSTERING_AGGLOMERATIVE :
405+ st .number_input (
406+ 'Clusters' ,
407+ value = 2 ,
408+ min_value = 1 ,
409+ key = K_CLUSTERING_AGGLOMERATIVE_N )
379410
380411
381412def add_mapper_settings ():
382413 df_X = st .session_state [S_RESULTS ].df_X
383414 st .markdown ('### ⚙️ Settings' )
384- with st .expander ('Lens' ):
415+ col_0 , col_1 = st .columns ([2 , 4 ])
416+ with col_0 :
417+ tab_0 , tab_1 , tab_2 = st .tabs ([
418+ '🔎 Lens' ,
419+ '🌐 Cover' ,
420+ '🧮 Clustering' ])
421+ with tab_0 :
385422 add_lens_settings ()
386- with st . expander ( '🌐 Cover' ) :
423+ with tab_1 :
387424 add_cover_settings ()
388- with st . expander ( '🧮 Clustering' ) :
425+ with tab_2 :
389426 add_clustering_settings ()
390- run = st .button (
391- '🚀 Run Mapper' ,
392- type = 'primary' ,
393- disabled = df_X is None )
427+ col1_0 , col1_1 , _ = st .columns ([2 , 2 , 2 ])
428+ with col1_0 :
429+ run = st .button (
430+ '🚀 Run Mapper' ,
431+ type = 'primary' ,
432+ use_container_width = True ,
433+ disabled = df_X is None )
394434 if run :
395435 with st .spinner ('⏳ Computing Mapper...' ):
396436 compute_mapper ()
397- add_download_graph ()
437+ with col_1 :
438+ add_graph_caption ()
439+ with col1_1 :
440+ add_download_graph ()
398441
399442
400443def get_lens_func ():
@@ -577,17 +620,20 @@ def add_graph_plot():
577620
578621def add_data ():
579622 st .markdown ('### 📊 Data' )
623+ col_0 , col_1 = st .columns ([2 , 4 ])
624+ with col_0 :
625+ add_data_source ()
580626 df_X = st .session_state [S_RESULTS ].df_X
581627 df_y = st .session_state [S_RESULTS ].df_y
582- add_data_source ()
583- df_X = st .session_state [S_RESULTS ].df_X
584- df_y = st .session_state [S_RESULTS ].df_y
585- if df_X is None :
586- return
587- cap = data_caption (df_X , df_y )
588- with st .expander (cap ):
628+ df_all = pd .DataFrame ()
629+ cap = 'empty dataset'
630+ if df_X is not None :
589631 df_all = pd .concat ([get_sample (df_y , frac = 1.0 ), get_sample (df_X , frac = 1.0 )], axis = 1 )
590- st .dataframe (df_all , height = 300 )
632+ cap = data_caption (df_X , df_y )
633+ with col_1 :
634+ st .markdown ('####' )
635+ st .dataframe (df_all , height = 250 , use_container_width = True )
636+ st .caption (cap )
591637
592638
593639def add_rendering ():
@@ -633,7 +679,6 @@ def main():
633679 st .markdown (APP_DESC )
634680 if S_RESULTS not in st .session_state :
635681 st .session_state [S_RESULTS ] = Results ()
636- st .markdown ('#' )
637682 add_data ()
638683 st .markdown ('#' )
639684 add_mapper_settings ()
0 commit comments