|
28 | 28 | METRIC_CONFIG_PATH = "../../../config/ssd_metric.json" |
29 | 29 | LOG_PATH = "../../../output/optimization/neutu_log.txt" |
30 | 30 |
|
31 | | -# def fake_gaussian(mean=[0.0, 0.0], cov=[[0.0, 1.0], [1.0, 0.0]], pos=(0,0)): |
32 | | -# rv = st.multivariate_normal(mean=[1.5, -1], cov=[[0.5, 0],[0, 0.5]]) |
33 | | -# rv2 = st.multivariate_normal(mean=[-0.1, 0.5], cov=[[1, 0],[0, 1]]) |
34 | | -# x = np.empty(shape=(1, 1, 2)) |
35 | | -# x[0,0,0] = pos[0] |
36 | | -# x[0,0,1] = pos[1] |
37 | | -# return rv.pdf(x) + rv2.pdf(x) |
38 | | - |
39 | | - |
40 | | -# def naive_optimize(gold_tree, rcn_configs, current_config, metric_method, metric_configs): |
41 | | -# """ |
42 | | -# A naive optimization, test every combination of test parameters |
43 | | -# recursive parameter searching is used for traversing all the combinations. |
44 | | -# """ |
45 | | -# clen = len(current_config) |
46 | | -# if clen == len(rcn_configs): |
47 | | -# # metric and get results |
48 | | -# # # check the value of configs |
49 | | -# # res = fake_gaussian(pos=configs) |
50 | | -# rec_config = read_json.read_json(json_file_path=CONFIG_PATH) |
51 | | -# current_config[0] = max(current_config[0], 0) |
52 | | -# current_config[1] = max(current_config[1], 0) |
53 | | -# print("[Info: ] minimalScoreAuto = {} minimalScoreSeed = {}".format( |
54 | | -# current_config[0], current_config[1]) |
55 | | -# ) |
56 | | -# rec_config["trace"]["default"]["minimalScoreAuto"] = rcn_configs[0] |
57 | | -# rec_config["trace"]["default"]["minimalScoreSeed"] = rcn_configs[1] |
58 | | -# |
59 | | -# # save new configs |
60 | | -# read_json.save_json(json_file_path=CONFIG_PATH, data=rec_config) |
61 | | -# REC_CMD = "{} --command --trace {} -o {} --config {}".format( |
62 | | -# NEUTU_PATH, ORIGIN_PATH, TEST_PATH, CONFIG_PATH |
63 | | -# ) |
64 | | -# print(REC_CMD) |
65 | | -# try: |
66 | | -# print("[Info: ] start tracing") |
67 | | -# os.system(REC_CMD) |
68 | | -# print("[Info: ] end tracing") |
69 | | -# except: |
70 | | -# raise Exception("[Error: ] error executing reconstruction") |
71 | | -# |
72 | | -# res_tree = swc_node.SwcTree() |
73 | | -# gold_tree = swc_node.SwcTree() |
74 | | -# res_tree.load(TEST_PATH) |
75 | | -# gold_tree.load(GOLD_PATH) |
76 | | -# |
77 | | -# main_score = g_metric_method(gold_tree, res_tree, g_metric_configs) |
78 | | -# # main_score = fake_gaussian(mean=[0.5, -0.3], cov=[[1, 0],[0, 1]], pos=current_config) |
79 | | -# global g_score |
80 | | -# if main_score > g_score: |
81 | | -# g_score = main_score |
82 | | -# print("max_score = {}, current_score = {}, x = {}, y = {}".format( |
83 | | -# g_score, main_score, current_config[0], current_config[1]) |
84 | | -# ) |
85 | | -# return None |
86 | | -# |
87 | | -# for item in rcn_configs[clen]: |
88 | | -# current_config.append(item) |
89 | | -# naive_optimize(gold_tree=gold_tree, |
90 | | -# rcn_method=rcn_method, |
91 | | -# rcn_configs=rcn_configs, |
92 | | -# current_config=current_config, |
93 | | -# metric_method=metric_method, |
94 | | -# metric_configs=metric_configs) |
95 | | -# current_config.pop() |
96 | | -# return None |
97 | | - |
98 | | - |
99 | | -# def naive_main(): |
100 | | -# global g_metric_configs |
101 | | -# global g_metric_method |
102 | | -# g_metric_method = ssd_metric.ssd_metric |
103 | | -# g_metric_configs = read_json.read_json(METRIC_CONFIG_PATH) |
104 | | -# |
105 | | -# z = np.zeros(shape=(50, 50)) |
106 | | -# for i in range(50): |
107 | | -# for j in range(50): |
108 | | -# rec_config = read_json.read_json(json_file_path=CONFIG_PATH) |
109 | | -# rec_config["trace"]["default"]["minimalScoreAuto"] = 0.02 * i |
110 | | -# rec_config["trace"]["default"]["minimalScoreSeed"] = 0.02 * j |
111 | | -# read_json.save_json(json_file_path=CONFIG_PATH, data=rec_config) |
112 | | -# REC_CMD = "{} --command --trace {} -o {} --config {}".format( |
113 | | -# NEUTU_PATH, ORIGIN_PATH, TEST_PATH, CONFIG_PATH |
114 | | -# ) |
115 | | -# print(REC_CMD) |
116 | | -# try: |
117 | | -# print("[Info: ] start tracing") |
118 | | -# os.system(REC_CMD) |
119 | | -# print("[Info: ] end tracing") |
120 | | -# except: |
121 | | -# raise Exception("[Error: ] error executing reconstruction") |
122 | | -# |
123 | | -# res_tree = swc_node.SwcTree() |
124 | | -# gold_tree = swc_node.SwcTree() |
125 | | -# res_tree.load(TEST_PATH) |
126 | | -# gold_tree.load(GOLD_PATH) |
127 | | -# |
128 | | -# main_score = g_metric_method(gold_tree, res_tree, g_metric_configs) |
129 | | -# print("[Info: ] call = {} minimalScoreAuto = {} minimalScoreSeed = {}".format( |
130 | | -# main_score["recall"], 0.02 * i, 0.02 * j |
131 | | -# )) |
132 | | -# z[i][j] = main_score["recall"] |
133 | | -# |
134 | | -# # print(z) |
135 | | -# x = np.linspace(0, 0.48, 50) |
136 | | -# y = np.linspace(0, 0.48, 50) |
137 | | -# X, Y = np.meshgrid(x, y) |
138 | | -# |
139 | | -# fig = plt.figure() |
140 | | -# # 创建一个三维坐标轴 |
141 | | -# ax = plt.axes(projection='3d') |
142 | | -# ax.contour3D(X, Y, z, 50, cmap='binary') |
143 | | -# ax.set_xlabel('x') |
144 | | -# ax.set_ylabel('y') |
145 | | -# ax.set_zlabel('z') |
146 | | -# ax.plot_surface(X, Y, z, rstride=1, cstride=1, cmap='viridis', edgecolor='none') |
147 | | -# ax.set_title('surface') |
148 | | -# plt.show() |
149 | | -# print(X.shape) |
150 | | - |
151 | 31 |
|
152 | 32 | def SA_optimize(configs=None, test_name=None, lock=None): |
153 | 33 | global g_metric_method |
@@ -205,7 +85,7 @@ def main(): |
205 | 85 | configs = (0.3, 0.3, 0.35, 0.5) |
206 | 86 | start = time.time() |
207 | 87 | sa_fast = SAFast(func=SA_optimize, |
208 | | - x0=configs, T_max=0.01, T_min=1e-5, q=0.96, L=20, max_stay_counter=30, upper=1, lower=0) |
| 88 | + x0=configs, T_max=0.01, T_min=1e-5, q=0.96, L=20, max_stay_counter=50, upper=1, lower=0) |
209 | 89 | best_configs, best_value = sa_fast.run() |
210 | 90 | print("[Info: ]best configs:\n" |
211 | 91 | " origin minimalScoreAuto = {}\n" |
|
0 commit comments