Skip to content

Commit c6c0c51

Browse files
authored
Merge pull request #43 from CSDLLab/0429_data
2020-4-29 add basic optimize tools
2 parents 2c45a72 + cb5dfdc commit c6c0c51

File tree

9 files changed

+481
-112
lines changed

9 files changed

+481
-112
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
{
2+
"title": "Skeletonization Schema",
3+
"type": "object",
4+
"properties": {
5+
"downsampleInterval": {
6+
"description": "Downsampling the stack before running skeletonization",
7+
"type": "array",
8+
"minItems": 3,
9+
"maxItems": 3,
10+
"items": {"type": "integer"}
11+
},
12+
"minimalLength": {
13+
"description": "Minimal length of the resulted branches",
14+
"type": "integer",
15+
"minimum": 0
16+
},
17+
"keepingSingleObject": {
18+
"description": "Keep an isolated object or not even if it is too small or short",
19+
"type": "boolean"
20+
},
21+
"rebase": {
22+
"description": "Reset the starting point to a terminal point?",
23+
"type": "boolean"
24+
},
25+
"fillingHole": {
26+
"description": "Reset the starting point to a terminal point?",
27+
"type": "boolean"
28+
},
29+
"maximalDistance": {
30+
"description": "Maximum distance to connect isolated branches.",
31+
"type": "number",
32+
"minimum": 0
33+
},
34+
"minimalObjectSize": {
35+
"description": "Minimal size of objects to skeletonize.",
36+
"type": "integer",
37+
"minimum": 0
38+
}
39+
}
40+
}
41+
42+
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"trace": {"tag": "trace configuration", "default": {"minimalScoreAuto": 0.909970122661639, "minimalScoreManual": 1.0, "minimalScoreSeed": 0.9494785696149678, "minimalScore2d": 0.5304879686303995, "refit": false, "spTest": false, "crossoverTest": false, "tuneEnd": true, "edgePath": false, "enhanceMask": true, "seedMethod": 1, "recover": 1, "maxEucDist": 10}}}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"trace": {
3+
"tag": "trace configuration",
4+
"default": {
5+
"minimalScoreAuto": 0.3,
6+
"minimalScoreManual": 0.3,
7+
"minimalScoreSeed": 0.35,
8+
"minimalScore2d": 0.5,
9+
"refit": false,
10+
"spTest": false,
11+
"crossoverTest": false,
12+
"tuneEnd": true,
13+
"edgePath": false,
14+
"enhanceMask": true,
15+
"seedMethod": 1,
16+
"recover": 1,
17+
"maxEucDist": 10
18+
}
19+
}
20+
}

pyneval/io/read_json.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os,sys
33

4+
45
def read_json(json_file_path, DEBUG=False):
56
json_file_path = os.path.normpath(json_file_path)
67
if not os.path.isfile(json_file_path) or not (json_file_path[-5:] == ".json" or json_file_path[-5:] == ".JSON"):
@@ -19,6 +20,21 @@ def read_json(json_file_path, DEBUG=False):
1920
return data
2021

2122

23+
def save_json(json_file_path, data, DEBUG=False):
24+
json_file_path = os.path.normpath(json_file_path)
25+
if not(json_file_path[-5:] == ".json" or json_file_path[-5:] == ".JSON"):
26+
raise Exception("[Error: ] \" {} \" is not a json file. Wrong format".format(json_file_path))
27+
try:
28+
with open(json_file_path, 'w') as f:
29+
json.dump(data, f)
30+
if DEBUG:
31+
print(type(data))
32+
except:
33+
raise Exception("[Error: ] json file save error")
34+
return True
35+
36+
2237
if __name__ == "__main__":
23-
print(dir)
24-
read_json(r'{"method": 2,"thereshold": "default"')
38+
# read_json(r'{"method": 2,"thereshold": "default"')
39+
test_json = read_json("/home/zhanghan/01_project/Pyneval/config/schemas/branch_metric_schema.json")
40+
save_json("./test.json", test_json)

pyneval/metric/ssd_metric.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,11 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree,
141141
gold_tree = swc_node.SwcTree()
142142

143143
sys.setrecursionlimit(10000000)
144-
gold_tree.load("E:\\00_project\\00_neural_reconstruction\\01_project\PyNeval\data\example_selected\\a.swc")
145-
test_tree.load("E:\\00_project\\00_neural_reconstruction\\01_project\PyNeval\output\\random_data\move\\a\\020\move_00.swc")
144+
gold_tree.load("/home/zhanghan/01_project/Pyneval/data/optimation/temp_gold.swc")
145+
test_tree.load("/home/zhanghan/01_project/Pyneval/data/optimation/output/temp_test.swc")
146146

147-
config = read_json.read_json("..\\..\\config\\ssd_metric.json")
148-
config_schema = read_json.read_json("..\\..\\config\\schemas\\ssd_metric_schema.json")
147+
config = read_json.read_json("../../config/ssd_metric.json")
148+
config_schema = read_json.read_json("../../config/schemas/ssd_metric_schema.json")
149149

150150
try:
151151
jsonschema.validate(config, config_schema)

pyneval/tools/optimize/SA.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
# @Time : 2019/8/17
4+
# @Author : github.com/guofei9987
5+
6+
import numpy as np
7+
from sko.base import SkoBase
8+
from sko.operators import mutation
9+
import multiprocessing as mp
10+
11+
CPU_CORE_NUM = 12
12+
13+
14+
class SimulatedAnnealingBase(SkoBase):
15+
"""
16+
DO SA(Simulated Annealing)
17+
18+
Parameters
19+
----------------
20+
func : function
21+
The func you want to do optimal
22+
n_dim : int
23+
number of variables of func
24+
x0 : array, shape is n_dim
25+
initial solution
26+
T_max :float
27+
initial temperature
28+
T_min : float
29+
end temperature
30+
L : int
31+
num of iteration under every temperature(Long of Chain)
32+
33+
Attributes
34+
----------------------
35+
36+
37+
Examples
38+
-------------
39+
See https://github.com/guofei9987/scikit-opt/blob/master/examples/demo_sa.py
40+
"""
41+
42+
def __init__(self, func, x0, T_max=100, T_min=1e-7, L=300, max_stay_counter=150, **kwargs):
43+
assert T_max > T_min > 0, 'T_max > T_min > 0'
44+
45+
self.func = func
46+
self.T_max = T_max # initial temperature
47+
self.T_min = T_min # end temperature
48+
self.L = int(L) # num of iteration under every temperature(also called Long of Chain)
49+
# stop if best_y stay unchanged over max_stay_counter times (also called cooldown time)
50+
self.max_stay_counter = max_stay_counter
51+
52+
self.n_dims = len(x0)
53+
54+
self.best_x = np.array(x0) # initial solution
55+
self.best_y = self.func(self.best_x)
56+
self.T = self.T_max
57+
self.iter_cycle = 0
58+
self.generation_best_X, self.generation_best_Y = [self.best_x], [self.best_y]
59+
# history reasons, will be deprecated
60+
self.best_x_history, self.best_y_history = self.generation_best_X, self.generation_best_Y
61+
62+
def get_new_x(self, x):
63+
u = np.random.uniform(-1, 1, size=self.n_dims)
64+
x_new = x + 20 * np.sign(u) * self.T * ((1 + 1.0 / self.T) ** np.abs(u) - 1.0)
65+
return x_new
66+
67+
def cool_down(self):
68+
self.T = self.T * 0.7
69+
70+
def isclose(self, a, b, rel_tol=1e-09, abs_tol=1e-30):
71+
return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
72+
73+
def run(self):
74+
x_current, y_current = self.best_x, self.best_y
75+
stay_counter = 0
76+
while True:
77+
# loop L times under the same Temperature
78+
for i in range(self.L):
79+
x_new = self.get_new_x(x_current)
80+
y_new = self.func(x_new)
81+
print("[Info: ]i/L = {}/{}".format(
82+
i, self.L
83+
))
84+
85+
# Metropolis
86+
df = y_new - y_current
87+
if df < 0 or np.exp(-df / self.T) > np.random.rand():
88+
x_current, y_current = x_new, y_new
89+
if y_new < self.best_y:
90+
self.best_x, self.best_y = x_new, y_new
91+
print("[Info: ] iter_cycle = {} T = {} stay_counter = {}".format(
92+
self.iter_cycle, self.T, stay_counter
93+
))
94+
print("[Info: ]origin minimalScoreAuto = {}\n"
95+
" minimalScoreManual = {}\n"
96+
" minimalScoreSeed = {}\n"
97+
" minimalScore2d = {}".format(
98+
self.best_x[0], self.best_x[1], self.best_x[2], self.best_x[3]
99+
))
100+
self.iter_cycle += 1
101+
self.cool_down()
102+
self.generation_best_Y.append(self.best_y)
103+
self.generation_best_X.append(self.best_x)
104+
105+
# if best_y stay for max_stay_counter times, stop iteration
106+
if self.isclose(self.best_y_history[-1], self.best_y_history[-2]):
107+
stay_counter += 1
108+
else:
109+
stay_counter = 0
110+
111+
if self.T < self.T_min:
112+
stop_code = 'Cooled to final temperature'
113+
break
114+
if stay_counter > self.max_stay_counter:
115+
stop_code = 'Stay unchanged in the last {stay_counter} iterations'.format(stay_counter=stay_counter)
116+
break
117+
118+
return self.best_x, self.best_y
119+
120+
fit = run
121+
122+
123+
class SAFast(SimulatedAnnealingBase):
124+
'''
125+
u ~ Uniform(0, 1, size = d)
126+
y = sgn(u - 0.5) * T * ((1 + 1/T)**abs(2*u - 1) - 1.0)
127+
128+
xc = y * (upper - lower)
129+
x_new = x_old + xc
130+
131+
c = n * exp(-n * quench)
132+
T_new = T0 * exp(-c * k**quench)
133+
'''
134+
135+
def __init__(self, func, x0, T_max=100, T_min=1e-7, L=300, max_stay_counter=150, **kwargs):
136+
# nit parent class
137+
super().__init__(func, x0, T_max, T_min, L, max_stay_counter, **kwargs)
138+
self.m, self.n, self.quench = kwargs.get('m', 1), kwargs.get('n', 1), kwargs.get('quench', 1)
139+
# upper and down are range of the parameters.
140+
self.lower, self.upper = kwargs.get('lower', -10), kwargs.get('upper', 10)
141+
self.c = self.m * np.exp(-self.n * self.quench)
142+
143+
def get_new_x(self, x):
144+
"""randomly search for a new x point"""
145+
r = np.random.uniform(-1, 1, size=self.n_dims)
146+
xc = np.sign(r) * self.T * ((1 + 1.0 / self.T) ** np.abs(r) - 1.0)
147+
x_new = x + xc * (self.upper - self.lower)
148+
return x_new
149+
150+
def cool_down(self):
151+
self.T = self.T_max * np.exp(-self.c * self.iter_cycle ** self.quench)
152+
153+
154+
# SA_fast is the default
155+
SA = SAFast
156+

0 commit comments

Comments
 (0)