Skip to content

Commit b2e1f69

Browse files
authored
Merge pull request #45 from CSDLLab/2020_05_07_down_sample
2020_05_07 add sample rate adjust in down sample
2 parents 9e30639 + cb5c5e0 commit b2e1f69

File tree

3 files changed

+21
-134
lines changed

3 files changed

+21
-134
lines changed

pyneval/tools/optimize/SA.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def run(self):
7676
stay_counter = 0
7777
while True:
7878
# loop L times under the same Temperature
79-
for i in range(self.L):
79+
i = 0
80+
while i < self.L:
8081
pool = mp.Pool(processes=CPU_CORE_NUM)
8182
res_y = []
8283
res_x = []
@@ -95,6 +96,7 @@ def run(self):
9596
pool.close()
9697
pool.join()
9798
for it in range(len(res_x)):
99+
i += 1
98100
x_new, y_new = res_y[it].get()
99101
print(x_new)
100102
print(y_new)

pyneval/tools/optimize/optimize.py

Lines changed: 1 addition & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -28,126 +28,6 @@
2828
METRIC_CONFIG_PATH = "../../../config/ssd_metric.json"
2929
LOG_PATH = "../../../output/optimization/neutu_log.txt"
3030

31-
# def fake_gaussian(mean=[0.0, 0.0], cov=[[0.0, 1.0], [1.0, 0.0]], pos=(0,0)):
32-
# rv = st.multivariate_normal(mean=[1.5, -1], cov=[[0.5, 0],[0, 0.5]])
33-
# rv2 = st.multivariate_normal(mean=[-0.1, 0.5], cov=[[1, 0],[0, 1]])
34-
# x = np.empty(shape=(1, 1, 2))
35-
# x[0,0,0] = pos[0]
36-
# x[0,0,1] = pos[1]
37-
# return rv.pdf(x) + rv2.pdf(x)
38-
39-
40-
# def naive_optimize(gold_tree, rcn_configs, current_config, metric_method, metric_configs):
41-
# """
42-
# A naive optimization, test every combination of test parameters
43-
# recursive parameter searching is used for traversing all the combinations.
44-
# """
45-
# clen = len(current_config)
46-
# if clen == len(rcn_configs):
47-
# # metric and get results
48-
# # # check the value of configs
49-
# # res = fake_gaussian(pos=configs)
50-
# rec_config = read_json.read_json(json_file_path=CONFIG_PATH)
51-
# current_config[0] = max(current_config[0], 0)
52-
# current_config[1] = max(current_config[1], 0)
53-
# print("[Info: ] minimalScoreAuto = {} minimalScoreSeed = {}".format(
54-
# current_config[0], current_config[1])
55-
# )
56-
# rec_config["trace"]["default"]["minimalScoreAuto"] = rcn_configs[0]
57-
# rec_config["trace"]["default"]["minimalScoreSeed"] = rcn_configs[1]
58-
#
59-
# # save new configs
60-
# read_json.save_json(json_file_path=CONFIG_PATH, data=rec_config)
61-
# REC_CMD = "{} --command --trace {} -o {} --config {}".format(
62-
# NEUTU_PATH, ORIGIN_PATH, TEST_PATH, CONFIG_PATH
63-
# )
64-
# print(REC_CMD)
65-
# try:
66-
# print("[Info: ] start tracing")
67-
# os.system(REC_CMD)
68-
# print("[Info: ] end tracing")
69-
# except:
70-
# raise Exception("[Error: ] error executing reconstruction")
71-
#
72-
# res_tree = swc_node.SwcTree()
73-
# gold_tree = swc_node.SwcTree()
74-
# res_tree.load(TEST_PATH)
75-
# gold_tree.load(GOLD_PATH)
76-
#
77-
# main_score = g_metric_method(gold_tree, res_tree, g_metric_configs)
78-
# # main_score = fake_gaussian(mean=[0.5, -0.3], cov=[[1, 0],[0, 1]], pos=current_config)
79-
# global g_score
80-
# if main_score > g_score:
81-
# g_score = main_score
82-
# print("max_score = {}, current_score = {}, x = {}, y = {}".format(
83-
# g_score, main_score, current_config[0], current_config[1])
84-
# )
85-
# return None
86-
#
87-
# for item in rcn_configs[clen]:
88-
# current_config.append(item)
89-
# naive_optimize(gold_tree=gold_tree,
90-
# rcn_method=rcn_method,
91-
# rcn_configs=rcn_configs,
92-
# current_config=current_config,
93-
# metric_method=metric_method,
94-
# metric_configs=metric_configs)
95-
# current_config.pop()
96-
# return None
97-
98-
99-
# def naive_main():
100-
# global g_metric_configs
101-
# global g_metric_method
102-
# g_metric_method = ssd_metric.ssd_metric
103-
# g_metric_configs = read_json.read_json(METRIC_CONFIG_PATH)
104-
#
105-
# z = np.zeros(shape=(50, 50))
106-
# for i in range(50):
107-
# for j in range(50):
108-
# rec_config = read_json.read_json(json_file_path=CONFIG_PATH)
109-
# rec_config["trace"]["default"]["minimalScoreAuto"] = 0.02 * i
110-
# rec_config["trace"]["default"]["minimalScoreSeed"] = 0.02 * j
111-
# read_json.save_json(json_file_path=CONFIG_PATH, data=rec_config)
112-
# REC_CMD = "{} --command --trace {} -o {} --config {}".format(
113-
# NEUTU_PATH, ORIGIN_PATH, TEST_PATH, CONFIG_PATH
114-
# )
115-
# print(REC_CMD)
116-
# try:
117-
# print("[Info: ] start tracing")
118-
# os.system(REC_CMD)
119-
# print("[Info: ] end tracing")
120-
# except:
121-
# raise Exception("[Error: ] error executing reconstruction")
122-
#
123-
# res_tree = swc_node.SwcTree()
124-
# gold_tree = swc_node.SwcTree()
125-
# res_tree.load(TEST_PATH)
126-
# gold_tree.load(GOLD_PATH)
127-
#
128-
# main_score = g_metric_method(gold_tree, res_tree, g_metric_configs)
129-
# print("[Info: ] call = {} minimalScoreAuto = {} minimalScoreSeed = {}".format(
130-
# main_score["recall"], 0.02 * i, 0.02 * j
131-
# ))
132-
# z[i][j] = main_score["recall"]
133-
#
134-
# # print(z)
135-
# x = np.linspace(0, 0.48, 50)
136-
# y = np.linspace(0, 0.48, 50)
137-
# X, Y = np.meshgrid(x, y)
138-
#
139-
# fig = plt.figure()
140-
# # 创建一个三维坐标轴
141-
# ax = plt.axes(projection='3d')
142-
# ax.contour3D(X, Y, z, 50, cmap='binary')
143-
# ax.set_xlabel('x')
144-
# ax.set_ylabel('y')
145-
# ax.set_zlabel('z')
146-
# ax.plot_surface(X, Y, z, rstride=1, cstride=1, cmap='viridis', edgecolor='none')
147-
# ax.set_title('surface')
148-
# plt.show()
149-
# print(X.shape)
150-
15131

15232
def SA_optimize(configs=None, test_name=None, lock=None):
15333
global g_metric_method
@@ -205,7 +85,7 @@ def main():
20585
configs = (0.3, 0.3, 0.35, 0.5)
20686
start = time.time()
20787
sa_fast = SAFast(func=SA_optimize,
208-
x0=configs, T_max=0.01, T_min=1e-5, q=0.96, L=20, max_stay_counter=30, upper=1, lower=0)
88+
x0=configs, T_max=0.01, T_min=1e-5, q=0.96, L=20, max_stay_counter=50, upper=1, lower=0)
20989
best_configs, best_value = sa_fast.run()
21090
print("[Info: ]best configs:\n"
21191
" origin minimalScoreAuto = {}\n"

pyneval/tools/re_sample.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def itp_ok(node=None, son=None, pa=None,
3535
return True
3636

3737

38-
def down_sample(swc_tree=None, rad_mul=1.50, center_dis=None, stage=0):
38+
def down_sample(swc_tree=None, rad_mul=1.50, center_dis=None, stage=0, k=1.0):
3939
stack = queue.LifoQueue()
4040
stack.put(swc_tree.root())
4141
down_pa = {}
@@ -62,9 +62,9 @@ def down_sample(swc_tree=None, rad_mul=1.50, center_dis=None, stage=0):
6262
son_dis, pa_dis, grand_dis = son.parent_distance(), node.distance(down_pa[node]), son.distance(pa)
6363

6464
# 确保针对采样率高的情况
65-
if stage == 0 and (son_dis > son.radius() + node.radius() or pa_dis > pa.radius() + node.radius()):
65+
if stage == 0 and (son_dis > k*(son.radius() + node.radius()) or pa_dis > k*(pa.radius() + node.radius())):
6666
continue
67-
if stage == 1 and (son_dis > son.radius() + node.radius() and pa_dis > pa.radius() + node.radius()):
67+
if stage == 1 and (son_dis > k*(son.radius() + node.radius()) and pa_dis > k*(pa.radius() + node.radius())):
6868
continue
6969
if itp_ok(node=node, son=son, pa=pa, rad_mul=rad_mul, center_dis=center_dis):
7070
is_active[node.get_id()] = False
@@ -77,18 +77,18 @@ def down_sample_swc_tree_command_line(swc_tree, config=None):
7777
rad_mul = config['rad_mul']
7878
center_dis = config['center_dis']
7979
stage = config['stage']
80-
return down_sample_swc_tree(swc_tree=swc_tree, rad_mul=rad_mul, center_dis=center_dis, stage=stage)
80+
return down_sample_swc_tree_stage(swc_tree=swc_tree, rad_mul=rad_mul, center_dis=center_dis, stage=stage)
8181

8282

83-
def down_sample_swc_tree(swc_tree, rad_mul=1.50, center_dis=None, stage=0):
83+
def down_sample_swc_tree_stage(swc_tree, rad_mul=1.50, center_dis=None, stage=0, k=1.0):
8484
'''
8585
:param swc_tree: the tree need to delete node
8686
:param rad_mul: defult=1.5
8787
:param center_dis: defult=None
88-
:param stage: 0: for 2 degree node, delete if one side is two close, 1: for 2 degree node, delete if two sides are two close
88+
:param stage: 0: for 2 degree node, delete if one side is too close, 1: for 2 degree node, delete if two sides are two close
8989
:return: swc_tree has changed in this function
9090
'''
91-
down_pa, is_activate = down_sample(swc_tree=swc_tree, rad_mul=rad_mul, center_dis=center_dis, stage=stage)
91+
down_pa, is_activate = down_sample(swc_tree=swc_tree, rad_mul=rad_mul, center_dis=center_dis, stage=stage, k=k)
9292
new_swc_tree = SwcTree()
9393
node_list = [node for node in PreOrderIter(swc_tree.root())]
9494
id_node_map = {-1: new_swc_tree.root()}
@@ -108,6 +108,12 @@ def down_sample_swc_tree(swc_tree, rad_mul=1.50, center_dis=None, stage=0):
108108
return new_swc_tree
109109

110110

111+
def down_sample_swc_tree(swc_tree, rad_mul=1.50, center_dis=None, k=1.0):
112+
res1 = down_sample_swc_tree_stage(swc_tree, rad_mul=rad_mul, center_dis=center_dis, stage=1, k=k)
113+
res2 = down_sample_swc_tree_stage(res1, rad_mul=rad_mul, center_dis=center_dis, stage=0, k=k)
114+
return res2
115+
116+
111117
def re_sample(swc_tree, son, pa, length_threshold):
112118
'''
113119
self recursive function, add node on the middle of edge son, pa
@@ -168,9 +174,8 @@ def up_sample_swc_tree(swc_tree, length_threshold=1.0):
168174

169175

170176
if __name__ == "__main__":
171-
file_name = "6144_12288_17664"
172177
swc_tree = SwcTree()
173-
swc_tree.load("D:\\02_project\\00_neural_tracing\\01_project\PyNeval\data\swc_cut_data\\{}.swc".format(file_name))
174-
up_sampled_swc_tree = up_sample_swc_tree(swc_tree=swc_tree, length_threshold=1.0)
175-
# up_sampled_swc_tree = down_sample_swc_tree(swc_tree=swc_tree)
176-
swc_save(up_sampled_swc_tree, "D:\\02_project\\00_neural_tracing\\01_project\PyNeval\output\\resample\\{}.swc".format(file_name))
178+
swc_tree.load("/home/zhanghan/01_project/Pyneval/data/raw/1-8.swc")
179+
180+
res_swc = down_sample_swc_tree(swc_tree=swc_tree, k=3)
181+
swc_save(res_swc, out_path="/home/zhanghan/01_project/Pyneval/output/down_sample.swc")

0 commit comments

Comments
 (0)