|
1 | 1 | from sklearn.model_selection import train_test_split |
2 | 2 | import numpy as np |
3 | 3 | import pandas as pd |
| 4 | +from collections import defaultdict, OrderedDict |
| 5 | +import plotly.figure_factory as ff |
4 | 6 |
|
5 | 7 | class opts: |
6 | 8 | def __init__(self, maxIter, init): |
@@ -48,3 +50,74 @@ def MTL_data_extract(df, task_feat, target): |
48 | 50 | y = tmp1.loc[:, df.columns == target].values |
49 | 51 | Y.append(np.array(y)) |
50 | 52 | return X, Y |
| 53 | + |
| 54 | +def RFA(df, task, target, top=10): |
| 55 | + def reformat(cols, w, top=10): |
| 56 | + # Task -> coln |
| 57 | + RFA = OrderedDict() |
| 58 | + cols = np.array(cols) |
| 59 | + fet, task = w.shape |
| 60 | + total = {} |
| 61 | + all_tasks = [] |
| 62 | + for i in range(task): |
| 63 | + col = w[:,i].flatten() |
| 64 | + index = sorted(range(len(col)), key=lambda i: col[i], reverse=True)[:top] |
| 65 | + e = set(cols[index]) |
| 66 | + RFA["task {}".format(i+1)] = e |
| 67 | + all_tasks.append("task {}".format(i+1)) |
| 68 | + total = set.union(e, total) |
| 69 | + print("all top {} colns are {}".format(top, total)) |
| 70 | + # Coln -> tasks |
| 71 | + ret = defaultdict(lambda: []) |
| 72 | + # dataframe to visualize |
| 73 | + df_v = pd.DataFrame(False, index=list(total), columns=all_tasks) |
| 74 | + df_v2 = pd.DataFrame(None, index = list(total), columns=[str(p+1) for p in range(len(all_tasks))]) |
| 75 | + df_RFA = [] |
| 76 | + for t in all_tasks: |
| 77 | + df_RFA.append(list(RFA[t])) |
| 78 | + for i in total: |
| 79 | + count = 1 |
| 80 | + for k, v in RFA.items(): |
| 81 | + if i in v: |
| 82 | + ret[i].append(k) |
| 83 | + df_v[k][i]=True |
| 84 | + df_v2[str(count)][i] = int(k[-2:]) |
| 85 | + if(len(k)==6): |
| 86 | + df_v2[str(count)][i] = int(k[-1]) |
| 87 | + count+=1 |
| 88 | + return df_v, all_tasks, list(total), df_v2, df_RFA, RFA |
| 89 | + |
| 90 | + def sort_df(df): |
| 91 | + fet, tsk = df.values.shape |
| 92 | + ret = pd.DataFrame(None, columns=list(df.columns)) |
| 93 | + ind = list(df.index) |
| 94 | + seq = [] |
| 95 | + for i in range(tsk): |
| 96 | + for j in range(fet): |
| 97 | + if(np.count_nonzero(~np.isnan(list(df_v2.iloc[j].values)))==i+1): |
| 98 | + ret.loc[len(ret)] = df_v2.iloc[j].values |
| 99 | + seq.append(ind[j]) |
| 100 | + ret = ret.rename(index={i:j for i,j in zip(range(fet), seq)}) |
| 101 | + return ret |
| 102 | + |
| 103 | + |
| 104 | + def get_z_text(z, mp): |
| 105 | + x, y = z.shape |
| 106 | + ret = np.empty([x, y],dtype="S10") |
| 107 | + for i in range(x): |
| 108 | + for j in range(y): |
| 109 | + ret[i][j]=mp[z[i][j]] |
| 110 | + return ret.astype(str).tolist() |
| 111 | + |
| 112 | + all_col = (df.loc[:, (df.columns != target)&(df.columns != tasks)].columns).tolist() |
| 113 | + df_v, all_tasks, total, df_v2, RFA, index = reformat(all_col, mtl_clf.W, top=top) |
| 114 | + |
| 115 | + mp = {i+1:"Task_{}".format(i) for i in range(len(X))} |
| 116 | + mp[None] = '' |
| 117 | + mp[np.nan] = '' |
| 118 | + df_v3 = sort_df(df_v2) |
| 119 | + z_text = get_z_text(df_v3.values, mp) |
| 120 | + fig = ff.create_annotated_heatmap(z = df_v3.values.tolist(), annotation_text=z_text, y=list(df_v3.index)) |
| 121 | + fig.update_xaxes(showticklabels=False, showgrid=False) |
| 122 | + return fig |
| 123 | + |
0 commit comments