Skip to content

Commit 9e30639

Browse files
authored
Merge pull request #44 from CSDLLab/2020_05_07_parallel
2020_5_7 parallal optimize
2 parents c6c0c51 + 94d47f5 commit 9e30639

File tree

8 files changed

+160
-51
lines changed

8 files changed

+160
-51
lines changed
Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,20 @@
1-
{"trace": {"tag": "trace configuration", "default": {"minimalScoreAuto": 0.909970122661639, "minimalScoreManual": 1.0, "minimalScoreSeed": 0.9494785696149678, "minimalScore2d": 0.5304879686303995, "refit": false, "spTest": false, "crossoverTest": false, "tuneEnd": true, "edgePath": false, "enhanceMask": true, "seedMethod": 1, "recover": 1, "maxEucDist": 10}}}
1+
{
2+
"trace": {
3+
"tag": "trace configuration",
4+
"default": {
5+
"minimalScoreAuto": 0.3,
6+
"minimalScoreManual": 0.3,
7+
"minimalScoreSeed": 0.35,
8+
"minimalScore2d": 0.5,
9+
"refit": false,
10+
"spTest": false,
11+
"crossoverTest": false,
12+
"tuneEnd": true,
13+
"edgePath": false,
14+
"enhanceMask": true,
15+
"seedMethod": 1,
16+
"recover": 1,
17+
"maxEucDist": 10
18+
}
19+
}
20+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"trace": {
3+
"tag": "trace configuration",
4+
"default": {
5+
"minimalScoreAuto": 0.42039409,
6+
"minimalScoreManual": 1,
7+
"minimalScoreSeed": 1,
8+
"minimalScore2d": 0.59725383,
9+
"refit": false,
10+
"spTest": false,
11+
"crossoverTest": false,
12+
"tuneEnd": true,
13+
"edgePath": false,
14+
"enhanceMask": true,
15+
"seedMethod": 1,
16+
"recover": 1,
17+
"maxEucDist": 10
18+
}
19+
}
20+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"trace": {
3+
"tag": "trace configuration",
4+
"default": {
5+
"minimalScoreAuto": 0.39791192,
6+
"minimalScoreManual": 1,
7+
"minimalScoreSeed": 98459854,
8+
"minimalScore2d": 0.263612,
9+
"refit": false,
10+
"spTest": false,
11+
"crossoverTest": false,
12+
"tuneEnd": true,
13+
"edgePath": false,
14+
"enhanceMask": true,
15+
"seedMethod": 1,
16+
"recover": 1,
17+
"maxEucDist": 10
18+
}
19+
}
20+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"trace": {
3+
"tag": "trace configuration",
4+
"default": {
5+
"minimalScoreAuto": 0.45898798,
6+
"minimalScoreManual": 0.74510809,
7+
"minimalScoreSeed": 0.82922591,
8+
"minimalScore2d": 0.0,
9+
"refit": false,
10+
"spTest": false,
11+
"crossoverTest": false,
12+
"tuneEnd": true,
13+
"edgePath": false,
14+
"enhanceMask": true,
15+
"seedMethod": 1,
16+
"recover": 1,
17+
"maxEucDist": 10
18+
}
19+
}
20+
}

pyneval/tools/optimize/SA.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# -*- coding: utf-8 -*-
33
# @Time : 2019/8/17
44
# @Author : github.com/guofei9987
5+
import copy
56

67
import numpy as np
78
from sko.base import SkoBase
@@ -52,7 +53,7 @@ def __init__(self, func, x0, T_max=100, T_min=1e-7, L=300, max_stay_counter=150,
5253
self.n_dims = len(x0)
5354

5455
self.best_x = np.array(x0) # initial solution
55-
self.best_y = self.func(self.best_x)
56+
self.best_y = self.func(self.best_x, "test_init.swc")[1]
5657
self.T = self.T_max
5758
self.iter_cycle = 0
5859
self.generation_best_X, self.generation_best_Y = [self.best_x], [self.best_y]
@@ -76,18 +77,40 @@ def run(self):
7677
while True:
7778
# loop L times under the same Temperature
7879
for i in range(self.L):
79-
x_new = self.get_new_x(x_current)
80-
y_new = self.func(x_new)
81-
print("[Info: ]i/L = {}/{}".format(
82-
i, self.L
83-
))
84-
85-
# Metropolis
86-
df = y_new - y_current
87-
if df < 0 or np.exp(-df / self.T) > np.random.rand():
88-
x_current, y_current = x_new, y_new
89-
if y_new < self.best_y:
90-
self.best_x, self.best_y = x_new, y_new
80+
pool = mp.Pool(processes=CPU_CORE_NUM)
81+
res_y = []
82+
res_x = []
83+
lock = mp.Manager().Lock()
84+
for j in range(CPU_CORE_NUM):
85+
x_new = self.get_new_x(x_current)
86+
for k in range(len(x_new)):
87+
x_new[k] = max(x_new[k], 0)
88+
x_new[k] = min(x_new[k], 1)
89+
res_y.append(
90+
pool.apply_async(self.func, args=tuple([x_new, "test256_{}".format(j), lock]))
91+
)
92+
res_x.append(x_new)
93+
94+
print("[Info: ]i/L = {}/{}".format(i, self.L))
95+
pool.close()
96+
pool.join()
97+
for it in range(len(res_x)):
98+
x_new, y_new = res_y[it].get()
99+
print(x_new)
100+
print(y_new)
101+
# Metropolis
102+
df = y_new - y_current
103+
if df < 0 or np.exp(-df / self.T) > np.random.rand():
104+
x_current, y_current = x_new, y_new
105+
print("[Info: ] Jump success")
106+
if y_new < self.best_y:
107+
print("[Info: ] Update success")
108+
self.best_x = copy.deepcopy(x_new)
109+
self.best_y = y_new
110+
break
111+
print("[Info: ] best x = {}".format(self.best_x))
112+
print("[Info: ] best y = {}".format(self.best_y))
113+
91114
print("[Info: ] iter_cycle = {} T = {} stay_counter = {}".format(
92115
self.iter_cycle, self.T, stay_counter
93116
))

pyneval/tools/optimize/optimize.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
g_gold_tree = None
1717
g_tar_tree = None
1818
g_rcn_method = None
19+
g_rcn_config = None
1920
g_metric_method = None
2021
g_metric_configs = None
2122

2223
NEUTU_PATH = "../../../../../00_program_file/00_neutu/bin/neutu"
2324
ORIGIN_PATH = "../../../data/optimation/test1/test1_test.tif"
2425
GOLD_PATH = "../../../data/optimation/test1/test1_gold.swc"
25-
TEST_PATH = "../../../data/optimation/output/test1_test.swc"
26-
CONFIG_PATH = "../../../config/fake_reconstruction_configs/test.json"
26+
TEST_PATH = "../../../data/optimation/output/"
27+
CONFIG_PATH = "../../../config/fake_reconstruction_configs/"
2728
METRIC_CONFIG_PATH = "../../../config/ssd_metric.json"
2829
LOG_PATH = "../../../output/optimization/neutu_log.txt"
2930

@@ -148,35 +149,25 @@
148149
# print(X.shape)
149150

150151

151-
def SA_optimize(configs, lock=None):
152-
global g_gold_tree
152+
def SA_optimize(configs=None, test_name=None, lock=None):
153153
global g_metric_method
154154
global g_metric_configs
155+
global g_rcn_config
155156

156-
# # check the value of configs
157-
# res = fake_gaussian(pos=configs)
157+
LOC_CONFIG_PATH = os.path.join(CONFIG_PATH, test_name+".json")
158+
LOC_TEST_PATH = os.path.join(TEST_PATH, test_name+".swc")
159+
rec_config = copy.deepcopy(g_rcn_config)
158160

159-
rec_config = copy.deepcopy(g_metric_configs)
161+
if configs is not None:
162+
rec_config["trace"]["default"]["minimalScoreAuto"] = configs[0]
163+
rec_config["trace"]["default"]["minimalScoreManual"] = configs[1]
164+
rec_config["trace"]["default"]["minimalScoreSeed"] = configs[2]
165+
rec_config["trace"]["default"]["minimalScore2d"] = configs[3]
160166

161-
for i in range(len(configs)):
162-
configs[i] = max(configs[i], 0)
163-
configs[i] = min(configs[i], 1)
167+
read_json.save_json(LOC_CONFIG_PATH, rec_config)
164168

165-
rec_config["trace"]["default"]["minimalScoreAuto"] = configs[0]
166-
rec_config["trace"]["default"]["minimalScoreManual"] = configs[1]
167-
rec_config["trace"]["default"]["minimalScoreSeed"] = configs[2]
168-
rec_config["trace"]["default"]["minimalScore2d"] = configs[3]
169-
170-
# save new configs
171-
if lock is not None:
172-
lock.acquire()
173-
try:
174-
read_json.save_json(json_file_path=CONFIG_PATH, data=rec_config)
175-
finally:
176-
if lock is not None:
177-
lock.release()
178169
REC_CMD = "{} --command --trace {} -o {} --config {} > {}".format(
179-
NEUTU_PATH, ORIGIN_PATH, TEST_PATH, CONFIG_PATH, LOG_PATH
170+
NEUTU_PATH, ORIGIN_PATH, LOC_TEST_PATH, LOC_CONFIG_PATH, LOG_PATH
180171
)
181172
try:
182173
os.system(REC_CMD)
@@ -185,7 +176,7 @@ def SA_optimize(configs, lock=None):
185176

186177
res_tree = swc_node.SwcTree()
187178
gold_tree = swc_node.SwcTree()
188-
res_tree.load(TEST_PATH)
179+
res_tree.load(os.path.join(TEST_PATH, test_name+".swc"))
189180
gold_tree.load(GOLD_PATH)
190181

191182
if lock is not None:
@@ -195,25 +186,26 @@ def SA_optimize(configs, lock=None):
195186
finally:
196187
if lock is not None:
197188
lock.release()
189+
score = (main_score["recall"] + main_score["precision"])/2
190+
print("[Info: ] ssd loss = {}".format(score))
198191

199-
print("[Info: ] ssd loss = {}".format(
200-
(main_score["recall"] + main_score["precision"])/2)
201-
)
202-
203-
return -(main_score["recall"] + main_score["precision"])/2
192+
return configs, -score
204193

205194

206195
def main():
207196
global g_metric_configs
208197
global g_metric_method
198+
global g_rcn_config
209199
g_metric_method = ssd_metric.ssd_metric
210200
g_metric_configs = read_json.read_json(METRIC_CONFIG_PATH)
201+
g_rcn_config = read_json.read_json(os.path.join(CONFIG_PATH, "test.json"))
202+
211203
# optimize with SA
212204
# configs here is the config of the reconstruction
213205
configs = (0.3, 0.3, 0.35, 0.5)
214206
start = time.time()
215207
sa_fast = SAFast(func=SA_optimize,
216-
x0=configs, T_max=0.01, T_min=1e-5, q=0.96, L=30, max_stay_counter=20, upper=1, lower=0)
208+
x0=configs, T_max=0.01, T_min=1e-5, q=0.96, L=20, max_stay_counter=30, upper=1, lower=0)
217209
best_configs, best_value = sa_fast.run()
218210
print("[Info: ]best configs:\n"
219211
" origin minimalScoreAuto = {}\n"
@@ -224,6 +216,11 @@ def main():
224216
" time = {}\n" .format(
225217
best_configs[0], best_configs[1], best_configs[2], best_configs[3], best_value, time.time() - start
226218
))
219+
g_rcn_config["trace"]["default"]["minimalScoreAuto"] = best_configs[0]
220+
g_rcn_config["trace"]["default"]["minimalScoreManual"] = best_configs[1]
221+
g_rcn_config["trace"]["default"]["minimalScoreSeed"] = best_configs[2]
222+
g_rcn_config["trace"]["default"]["minimalScore2d"] = best_configs[3]
223+
read_json.save_json(os.path.join(CONFIG_PATH, "best_x_{}.json".format(time.time())), g_rcn_config)
227224
# plot the result.
228225
plt.plot(pd.DataFrame(sa_fast.best_y_history).cummin(axis=0))
229226
plt.xlabel("iterations")
@@ -234,3 +231,9 @@ def main():
234231

235232
if __name__ == "__main__":
236233
main()
234+
235+
# g_metric_method = ssd_metric.ssd_metric
236+
# g_metric_configs = read_json.read_json(METRIC_CONFIG_PATH)
237+
# g_rcn_config = read_json.read_json(os.path.join(CONFIG_PATH, "test3best.json"))
238+
#
239+
# SA_optimize(test_name="test3best")

pyneval/tools/swc_cut.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pyneval.io import swc_writer
44

55
SWC_PATH = "../../data/optimation/gold.swc"
6-
OUTPUT_PATH = "../../data/optimation/test3_gold.swc"
6+
OUTPUT_PATH = "../../data/optimation/test4_gold.swc"
77

88

99
def cut_swc_rectangle(swc_tree, LFD_pos, RBT_pos):
@@ -34,7 +34,11 @@ def cut_swc_rectangle(swc_tree, LFD_pos, RBT_pos):
3434
# load origin swc file
3535
swc_tree.load(SWC_PATH)
3636

37-
cut_swc_rectangle(swc_tree, (0, 0, 0), (256, 256, 256))
37+
cut_swc_rectangle(swc_tree, (260, 30, 125), (324, 94, 189))
3838
swc_tree.get_node_list(update=True)
39+
for node in swc_tree.get_node_list():
40+
node.set_x(node.get_x()-260)
41+
node.set_y(node.get_y()-30)
42+
node.set_z(node.get_z()-125)
3943
# the result will be saved in:
4044
swc_writer.swc_save(swc_tree=swc_tree, out_path=OUTPUT_PATH)

pyneval/tools/tiff_cut.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66
from pyneval.metric.utils.tiff_utils import front_expend_step
77
from pyneval.model.swc_node import SwcTree
88

9-
TIFF_PATH = "../../data/optimation/6656_gold.tif"
9+
TIFF_PATH = "../../data/optimation/6656_test.tif"
1010

1111

12-
def tiff_cut(tiff_data, LDF = (0, 0, 0), len = (256, 256, 256)):
12+
def tiff_cut(tiff_data, LDF = (0, 0, 0), len = (64, 64, 64)):
1313
res = tiff_data[LDF[0]:LDF[0]+len[0], LDF[1]:LDF[1]+len[1], LDF[2]:LDF[2]+len[2]]
1414
return res
1515

1616

1717
if __name__ == "__main__":
1818
tiff_file = TiffFile.imread(TIFF_PATH)
19-
data = tiff_cut(tiff_file, LDF=(0, 0, 0))
20-
TiffFile.imsave("../../data/optimation/test3_gold.tif", data)
19+
data = tiff_cut(tiff_file, LDF=(125, 30, 260))
20+
TiffFile.imsave("../../data/optimation/test4_test.tif", data)
2121
# print("??")
2222

0 commit comments

Comments
 (0)