Skip to content

Commit 2c45a72

Browse files
authored
Merge pull request #42 from CSDLLab/3_6_draw_pic
04_16
2 parents b05eae8 + 2f3babf commit 2c45a72

File tree

15 files changed

+482
-216
lines changed

15 files changed

+482
-216
lines changed

config/branch_metric.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
{
22
"threshold_dis": 2,
33
"threshold_mode": 1,
4-
"metric_mode": 1,
54
"true_positive_type": 3,
65
"false_negative_type": 4,
76
"false_positive_type": 5
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"cfg1": [-5, 5, 100],
3+
"cfg2": [-5, 5, 100]
4+
}

config/schemas/branch_metric_schema.json

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@
44
"required": [
55
"threshold_dis",
66
"threshold_mode",
7-
"metric_mode",
87
"true_positive_type",
98
"false_negative_type",
109
"false_positive_type"
1110
],
1211
"properties": {
1312
"threshold_dis": {"type": "number", "exclusiveMinimum": 0},
1413
"threshold_mode": {"type": "number", "enum": [1, 2]},
15-
"metric_mode": {"type": "number", "enum": [1, 2]},
1614
"true_positive_type": {"type": "integer", "exclusiveMinimum": 0},
1715
"false_negative_type": {"type": "integer", "exclusiveMinimum": 0},
1816
"false_positive_type": {"type": "integer", "exclusiveMinimum": 0}

pyneval/cli/pyneval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,4 @@ def run(DEBUG=True):
306306

307307
# pyneval --gold .\\data\test_data\geo_metric_data\gold_34_23_10.swc --test .\data\test_data\geo_metric_data\test_34_23_10.swc --metric branch_metric
308308

309+
# pyneval --gold .\\data\test_data\geo_metric_data\gold_34_23_10.swc --test .\data\test_data\geo_metric_data\test_34_23_10.swc --metric branch_metric

pyneval/metric/branch_leaf_metric.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def get_colored_tree(test_node_list, gold_node_list, switch, km, color):
6666

6767
def score_point_distance(gold_tree: swc_node.SwcTree, test_tree: swc_node.SwcTree,
6868
test_node_list: list, gold_node_list: list,
69-
threshold_dis: float, color: list, metric_mode: int):
69+
threshold_dis: float, color: list):
7070
"""
7171
get minimum matching distance by running KM algorithm
7272
than calculte the return value according to matching result
@@ -78,9 +78,6 @@ def score_point_distance(gold_tree: swc_node.SwcTree, test_tree: swc_node.SwcTre
7878
threshold_dis: if the distance of two node are larger than this threshold,
7979
they are considered unlimited far
8080
color(List): color id of tp, fn, fp nodes
81-
metric_mode(1 or 2):
82-
mode = 1: distance between nodes are calculated as euclidean distance
83-
mode = 2: distance between nodes are calculated as distance on the gold tree
8481
Returns:
8582
gold_len(int): length of gold_node_list
8683
test_len(int): length of test_node_list
@@ -92,8 +89,8 @@ def score_point_distance(gold_tree: swc_node.SwcTree, test_tree: swc_node.SwcTre
9289
pt_cost: a composite value calculated by tp, fn, fp and threshold
9390
iso_node_num: number of nodes in test tree without parents or children
9491
"""
95-
test_gold_dict = point_match_utils.get_swc2swc_dicts(src_node_list=test_swc_tree.get_node_list(),
96-
tar_node_list=gold_swc_tree.get_node_list())
92+
test_gold_dict = point_match_utils.get_swc2swc_dicts(src_node_list=test_tree.get_node_list(),
93+
tar_node_list=gold_tree.get_node_list())
9794
# disgraph is a 2D ndarray store the distance between nodes in gold and test
9895
# test_node_list contains only branch or leaf nodes
9996
dis_graph, switch, test_len, gold_len = km_utils.get_dis_graph(gold_tree=gold_tree,
@@ -102,7 +99,7 @@ def score_point_distance(gold_tree: swc_node.SwcTree, test_tree: swc_node.SwcTre
10299
gold_node_list=gold_node_list,
103100
test_gold_dict=test_gold_dict,
104101
threshold_dis=threshold_dis,
105-
metric_mode=metric_mode)
102+
metric_mode=1)
106103
# create a KM object and calculate the minimum match
107104
km = km_utils.KM(maxn=max(test_len, gold_len)+10, nx=test_len, ny=gold_len, G=dis_graph)
108105
km.solve()
@@ -150,7 +147,6 @@ def branch_leaf_metric(gold_swc_tree, test_swc_tree, config):
150147
"""
151148
# read configs
152149
threshold_dis = config["threshold_dis"]
153-
metric_mode = config["metric_mode"]
154150
threshold_mode = config["threshold_mode"]
155151

156152
# in threshold mode 2, threshold is a multiple of the average length of edges
@@ -176,8 +172,7 @@ def branch_leaf_metric(gold_swc_tree, test_swc_tree, config):
176172
test_node_list=test_branch_swc_list,
177173
gold_node_list=gold_branch_swc_list,
178174
threshold_dis=threshold_dis,
179-
color=color,
180-
metric_mode=metric_mode)
175+
color=color)
181176

182177
branch_result = {
183178
"gold_len": branch_result_tuple[0],
@@ -199,8 +194,8 @@ def branch_leaf_metric(gold_swc_tree, test_swc_tree, config):
199194
gold_swc_tree = swc_node.SwcTree()
200195
test_swc_tree = swc_node.SwcTree()
201196

202-
test_swc_tree.load("../../data/test_data/topo_metric_data/gold_fake_data4.swc")
203-
gold_swc_tree.load("../../data/test_data/topo_metric_data/test_fake_data4.swc")
197+
gold_swc_tree.load("../../data/example_selected/a.swc")
198+
test_swc_tree.load("../../output/random_data/move/a/010/move_03.swc")
204199

205200
config = read_json.read_json("..\\..\\config\\branch_metric.json")
206201
config_schema = read_json.read_json("..\\..\\config\\schemas\\branch_metric_schema.json")
@@ -209,9 +204,6 @@ def branch_leaf_metric(gold_swc_tree, test_swc_tree, config):
209204
except Exception as e:
210205
raise Exception("[Error: ]Error in analyzing config json file")
211206

212-
config["metric_mode"] = 2
213-
config["threshold_dis"] = 1
214-
config["threshold_mode"] = 2
215207
branch_result = \
216208
branch_leaf_metric(test_swc_tree=test_swc_tree, gold_swc_tree=gold_swc_tree, config=config)
217209
print("---------------Result---------------")

pyneval/metric/length_metric.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,17 +179,16 @@ def geometry():
179179
goldTree = swc_node.SwcTree()
180180
testTree = swc_node.SwcTree()
181181
sys.setrecursionlimit(10000000)
182-
goldTree.load("..\\..\\data\\test_data\\geo_metric_data\\gold_34_23_10.swc")
183-
testTree.load("..\\..\\data\\test_data\\geo_metric_data\\test_34_23_10.swc")
182+
testTree.load("E:\\00_project\\00_neural_reconstruction\\01_project\PyNeval\data\example_selected\\a.swc")
183+
goldTree.load("E:\\00_project\\00_neural_reconstruction\\01_project\PyNeval\output\\random_data\move\\a\\020\move_00.swc")
184184

185185
config = read_json.read_json("..\\..\\config\\length_metric.json")
186186
config_schema = read_json.read_json("..\\..\\config\\schemas\\length_metric_schema.json")
187-
188187
try:
189188
jsonschema.validate(config, config_schema)
190189
except Exception as e:
191190
raise Exception("[Error: ]Error in analyzing config json file")
192-
# config["detail_path"] = "..\\..\\output\\length_output\\length_metric_detail.swc"
191+
config["detail_path"] = "..\\..\\output\\length_output\\length_metric_detail.swc"
193192

194193
lm_res = length_metric(gold_swc_tree=goldTree,
195194
test_swc_tree=testTree,

pyneval/metric/link_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def link_metric(gold_swc_tree, test_swc_tree, config):
9393
gold_swc_tree = SwcTree()
9494
test_swc_tree = SwcTree()
9595
test_swc_tree.load("..\\..\\data\\test_data\\topo_metric_data\\gold_fake_data4.swc")
96-
gold_swc_tree.load("..\\..\\data\\test_data\\topo_metric_data\\test_fake_data4.swc")
96+
gold_swc_tree.load("..\\..\\data\\test_data\\topo_metric_data\\gold_fake_data4.swc")
9797
config = read_json("..\\..\\config\\link_metric.json")
9898
config_schema = read_json("..\\..\\config\\schemas\\link_metric_schema.json")
9999

pyneval/metric/ssd_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree,
141141
gold_tree = swc_node.SwcTree()
142142

143143
sys.setrecursionlimit(10000000)
144-
gold_tree.load("..\\..\\data\\test_data\\geo_metric_data\\gold_34_23_10.swc")
145-
test_tree.load("..\\..\\data\\test_data\\geo_metric_data\\test_34_23_10.swc")
144+
gold_tree.load("E:\\00_project\\00_neural_reconstruction\\01_project\PyNeval\data\example_selected\\a.swc")
145+
test_tree.load("E:\\00_project\\00_neural_reconstruction\\01_project\PyNeval\output\\random_data\move\\a\\020\move_00.swc")
146146

147147
config = read_json.read_json("..\\..\\config\\ssd_metric.json")
148148
config_schema = read_json.read_json("..\\..\\config\\schemas\\ssd_metric_schema.json")

pyneval/metric/utils/km_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def get_simple_lca_length(std_tree, test_gold_dict, node1, node2, switch):
2626

2727
lca_id = std_tree.get_lca(tmp_node1.get_id(), tmp_node2.get_id())
2828
if lca_id == -1:
29-
return DINF
29+
return config_utils.DINF
3030
lca_node = std_id_node_dict[lca_id]
3131
return tmp_node1.root_length + tmp_node2.root_length - 2 * lca_node.root_length
3232

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
demo_func = lambda x: x[0] ** 2 + (x[1] - 0.05) ** 2 + x[2] ** 2
2+
3+
# %% Do SA
4+
from sko.SA import SA
5+
6+
sa = SA(func=demo_func, x0=[1, 1, 0], T_max=1, T_min=1e-9, L=300, max_stay_counter=150)
7+
best_x, best_y = sa.run()
8+
print('best_x:', best_x, 'best_y', best_y)
9+
10+
# %% Plot the result
11+
import matplotlib.pyplot as plt
12+
import pandas as pd
13+
14+
plt.plot(pd.DataFrame(sa.best_y_history).cummin(axis=0))
15+
plt.show()
16+
17+
# %%
18+
from sko.SA import SAFast
19+
20+
sa_fast = SAFast(func=demo_func, x0=[1, 1, 1], T_max=1, T_min=1e-9, q=0.99, L=300, max_stay_counter=150)
21+
sa_fast.run()
22+
print('Fast Simulated Annealing: best_x is ', sa_fast.best_x, 'best_y is ', sa_fast.best_y)
23+
24+
# %%
25+
from sko.SA import SABoltzmann
26+
27+
sa_boltzmann = SABoltzmann(func=demo_func, x0=[1, 1, 1], T_max=1, T_min=1e-9, q=0.99, L=300, max_stay_counter=150)
28+
sa_boltzmann.run()
29+
print('Boltzmann Simulated Annealing: best_x is ', sa_boltzmann.best_x, 'best_y is ', sa_fast.best_y)
30+
31+
# %%
32+
from sko.SA import SACauchy
33+
34+
sa_cauchy = SACauchy(func=demo_func, x0=[1, 1, 1], T_max=1, T_min=1e-9, q=0.99, L=300, max_stay_counter=150)
35+
sa_cauchy.run()
36+
print('Cauchy Simulated Annealing: best_x is ', sa_cauchy.best_x, 'best_y is ', sa_cauchy.best_y)

0 commit comments

Comments
 (0)