Skip to content

Commit 5e0bae4

Browse files
authored
Merge pull request #50 from CSDLLab/2021_0602_get_output_organized
reorganized the pyneval entry file, unify output and detail
2 parents 6ca8154 + cbf449e commit 5e0bae4

File tree

5 files changed

+141
-154
lines changed

5 files changed

+141
-154
lines changed

pyneval/cli/pyneval.py

Lines changed: 111 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -4,51 +4,57 @@
44
import jsonschema
55
import pyneval
66
from pyneval.io.read_swc import read_swc_trees
7-
from pyneval.io.read_json import read_json
7+
from pyneval.io import read_json
88
from pyneval.io.swc_writer import swc_save
99
from pyneval.io.read_tiff import read_tiffs
10-
from pyneval.metric.diadem_metric import diadem_metric
11-
from pyneval.metric.length_metric import length_metric
12-
from pyneval.metric.volume_metric import volume_metric
13-
from pyneval.metric.branch_leaf_metric import branch_leaf_metric
14-
from pyneval.metric.link_metric import link_metric
10+
from pyneval.metric import diadem_metric
11+
from pyneval.metric import length_metric
12+
from pyneval.metric import volume_metric
13+
from pyneval.metric import branch_leaf_metric
14+
from pyneval.metric import link_metric
1515
from pyneval.metric import ssd_metric
1616

1717
METRICS = {
1818
'diadem_metric': {
1919
'config': "diadem_metric.json",
2020
'description': "DIADEM metric (https://doi.org/10.1007/s12021-011-9117-y)",
2121
'alias': ['DM'],
22+
'method': diadem_metric.diadem_metric,
2223
'public': True
2324
},
2425
'ssd_metric': {
2526
'config': "ssd_metric.json",
2627
'description': "minimum square error between up-sampled gold and test trees",
2728
'alias': ['SM'],
29+
'method': ssd_metric.ssd_metric,
2830
'public': True
2931
},
3032
'length_metric': {
3133
'config': "length_metric.json",
3234
'description': "length of matched branches and fibers",
3335
'alias': ['ML'],
36+
'method': length_metric.length_metric,
3437
'public': True
3538
},
3639
'volume_metric': {
3740
'config': "volume_metric.json",
3841
'description': "volume overlap",
3942
'alias': ['VM'],
43+
'method': volume_metric.volume_metric,
4044
'public': False
4145
},
4246
'branch_metric': {
4347
'config': "branch_metric.json",
4448
'description': "quality of critical points",
4549
'alias': ['BM'],
50+
'method': branch_leaf_metric.branch_leaf_metric,
4651
'public': True
4752
},
4853
'link_metric': {
4954
'config': "link_metric.json",
5055
'description': "",
5156
'alias': ['LM'],
57+
'method': link_metric.link_metric,
5258
'public': False
5359
},
5460
}
@@ -90,24 +96,27 @@ def get_metric_config_schema_path(metric, root_dir):
9096
schema_dir = os.path.join(config_dir, "schemas")
9197
return os.path.join(schema_dir, get_metric_config(metric)['config'][:-5]+"_schema.json")
9298

99+
def get_metric_method(metric):
100+
return get_metric_config(metric)['method']
101+
93102
def read_parameters():
94103
parser = argparse.ArgumentParser(
95104
description="pyneval 1.0"
96105
)
97106

98-
parser.add_argument(
99-
"--test",
100-
"-T",
101-
help="a list of SWC files for evaluation",
102-
required=False,
103-
nargs='*',
104-
)
105107
parser.add_argument(
106108
"--gold",
107109
"-G",
108110
help="path to the gold-standard SWC file",
109111
required=True
110112
)
113+
parser.add_argument(
114+
"--test",
115+
"-T",
116+
help="a list of SWC files for evaluation",
117+
required=True,
118+
nargs='*',
119+
)
111120
parser.add_argument(
112121
"--metric",
113122
"-M",
@@ -117,174 +126,137 @@ def read_parameters():
117126
parser.add_argument(
118127
"--output",
119128
"-O",
120-
help="path to the output file (output to screen if not specified)",
129+
help="metric output path, including different scores of the metric",
121130
required=False
122131
)
123132
parser.add_argument(
124-
"--config",
125-
"-C",
126-
help="custom configuration file for the specified metric",
133+
"--detail",
134+
"-D",
135+
help="detail \"type\" marked for gold/test SWC file, including marked swc trees",
127136
required=False
128137
)
129138
parser.add_argument(
130-
"--reverse",
131-
"-R",
132-
help="output the answer when we switch the gold and test tree",
139+
"--config",
140+
"-C",
141+
help="custom configuration file for the specified metric",
133142
required=False
134143
)
135144
parser.add_argument(
136145
"--debug",
137-
"-D",
138146
help="print debug info or not",
139147
required=False
140148
)
141149
return parser.parse_args()
142150

143151

144-
# command program
145-
def run(DEBUG=True):
146-
# init path parameter
147-
abs_dir = os.path.abspath("")
152+
def init(abs_dir):
148153
sys.path.append(abs_dir)
149154
sys.path.append(os.path.join(abs_dir, "src"))
150155
sys.path.append(os.path.join(abs_dir, "test"))
151156
sys.setrecursionlimit(1000000)
152157

153-
# read parameter
154-
try:
155-
args = read_parameters()
156-
except:
157-
return 1
158-
159-
# set config
160-
# gold/test files
161-
if args.test is None:
162-
test_swc_files = []
163-
else:
164-
test_swc_files = [os.path.join(abs_dir, path) for path in args.test]
165-
gold_swc_file = os.path.join(abs_dir, args.gold)
166-
gold_file_name = os.path.basename(gold_swc_file)
167158

168-
# reverse
169-
reverse = args.reverse
170-
if reverse is None:
171-
reverse = True
159+
def set_configs(abs_dir, args):
160+
# argument: gold
161+
gold_swc_path = os.path.join(abs_dir, args.gold)
162+
if not (os.path.isfile(gold_swc_path) and gold_swc_path[-4:] == ".swc"):
163+
raise Exception("[Error: ] gold standard file is not a swc file")
164+
gold_swc_tree = read_swc_trees(gold_swc_path)[0] # SwcTree
172165

173-
# metric
166+
# argument: metric
174167
metric = get_root_metric(args.metric)
175168
if not metric:
176169
print("\nERROR: The metric '{}' is not supported.".format(args.metric))
177170
print("\nValid options for --metric:\n")
178171
print(get_metric_summary(True))
179172
return 1
180173

181-
# output path
182-
output_dest = args.output
183-
if output_dest is not None:
184-
output_dest = os.path.join(abs_dir, output_dest)
174+
# argument: test
175+
test_swc_paths = [os.path.join(abs_dir, path) for path in args.test]
176+
test_swc_trees = []
177+
# read test trees
178+
if metric in ['volume_metric', 'VM']:
179+
for file in test_swc_paths:
180+
test_swc_trees += read_tiffs(file)
181+
else:
182+
for file in test_swc_paths:
183+
test_swc_trees += read_swc_trees(file)
184+
185+
# info: how many trees read
186+
print("There are {} test image(s)".format(len(test_swc_trees)))
187+
188+
# argument: output
189+
output_dir = None
190+
if args.output:
191+
output_dir = os.path.join(abs_dir, args.output)
185192

186-
# config
193+
# argument: detail
194+
detail_dir = None
195+
if args.detail:
196+
detail_dir = os.path.join(abs_dir, args.detail)
197+
198+
# argument: config
187199
config_path = args.config
188200
if config_path is None:
189201
config_path = get_metric_config_path(metric, abs_dir)
190202
config_schema_path = get_metric_config_schema_path(metric, abs_dir)
191-
config = read_json(config_path)
192-
config_schema = read_json(config_schema_path)
203+
204+
config = read_json.read_json(config_path)
205+
config_schema = read_json.read_json(config_schema_path)
193206
try:
194207
jsonschema.validate(config, config_schema)
195208
except Exception:
196209
raise Exception("[Error: ]Error in analyzing config json file")
197210

198-
test_swc_trees, test_tiffs = [], []
199-
# read test trees, gold trees and configs
200-
if metric in ['volume_metric', 'VM']:
201-
for file in test_swc_files:
202-
test_tiffs += read_tiffs(file)
211+
# argument: debug
212+
is_debug = args.debug
213+
214+
return gold_swc_tree, test_swc_trees, metric, output_dir, detail_dir, config, is_debug
215+
216+
217+
def excute_metric(metric, gold_swc_tree, test_swc_tree, config, detail_dir, output_dir, file_name_extra=""):
218+
metric_method = get_metric_method(metric)
219+
test_swc_name = test_swc_tree.get_name()
220+
gold_swc_name = gold_swc_tree.get_name()
221+
222+
result = metric_method(gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree, config=config)
223+
224+
print("---------------Result---------------")
225+
for key in result:
226+
print("{} = {}".format(key.ljust(15, ' '), result[key]))
227+
print("----------------End-----------------\n")
228+
229+
if file_name_extra == "reverse":
230+
file_name = gold_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc"
203231
else:
204-
for file in test_swc_files:
205-
test_swc_trees += read_swc_trees(file)
232+
file_name = test_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc"
206233

207-
gold_swc_trees = read_swc_trees(gold_swc_file)
234+
if detail_dir:
235+
swc_save(swc_tree=gold_swc_tree,
236+
out_path=os.path.join(detail_dir, file_name))
208237

209-
# info: how many trees read
210-
print("There are {} test image(s) and {} gold image(s)".format(len(test_swc_trees), len(gold_swc_trees)))
211-
if len(gold_swc_trees) == 0:
212-
raise Exception("[Error: ] No gold image detected")
213-
if len(gold_swc_trees) > 1:
214-
print("[Warning: ] More than one gold image detected, only the first one will be used")
215-
216-
# entries to different metrics
217-
gold_swc_treeroot = gold_swc_trees[0]
218-
for test_tiff in test_tiffs:
219-
if metric == "volume_metric":
220-
volume_result = volume_metric(tiff_test=test_tiff, swc_gold=gold_swc_treeroot, config=config)
221-
print(volume_result["recall"])
222-
223-
for test_swc_treeroot in test_swc_trees:
224-
if metric == "diadem_metric":
225-
diadem_res = diadem_metric(swc_test_tree=test_swc_treeroot,
226-
swc_gold_tree=gold_swc_treeroot,
227-
config=config)
228-
print("score = {}".format(diadem_res["final_score"]))
229-
if reverse:
230-
rev_diadem_res = diadem_metric(swc_test_tree=gold_swc_treeroot,
231-
swc_gold_tree=test_swc_treeroot,
232-
config=config)
233-
print("rev_score = {}".format(rev_diadem_res["final_score"]))
234-
235-
if metric == "ssd_metric":
236-
ssd_res = ssd_metric.ssd_metric(gold_swc_treeroot, test_swc_treeroot, config)
237-
print("ssd score = {}\n"
238-
"recall = {}%\n"
239-
"precision = {}%".format(round(ssd_res["avg_score"], 2),
240-
round(ssd_res["recall"] * 100, 2),
241-
round(ssd_res["precision"] * 100, 2)))
242-
243-
if metric == "length_metric":
244-
lm_res = length_metric(gold_swc_treeroot, test_swc_treeroot, config)
245-
print("Recall = {} Precision = {}".format(lm_res["recall"], lm_res["precision"]))
246-
247-
if output_dest:
248-
swc_save(test_swc_treeroot, output_dest)
249-
if reverse:
250-
if "detail" in config:
251-
config["detail"] = config["detail"][:-4] + "_reverse.swc"
252-
lm_res = length_metric(test_swc_treeroot, gold_swc_treeroot, config)
253-
print("Recall = {} Precision = {}".format(lm_res["recall"], lm_res["precision"]))
254-
if output_dest:
255-
swc_save(gold_swc_treeroot, output_dest[:-4]+"_reverse.swc")
256-
if metric == "branch_metric":
257-
branch_res = branch_leaf_metric(gold_swc_tree=gold_swc_treeroot,
258-
test_swc_tree=test_swc_treeroot,
259-
config=config)
260-
print("---------------Result---------------")
261-
print("gole_branch_num = {}, test_branch_num = {}\n"
262-
"true_positive_number = {}\n"
263-
"false_negative_num = {}\n"
264-
"false_positive_num = {}\n"
265-
"matched_mean_distance = {}\n"
266-
"matched_sum_distance = {}\n"
267-
"pt_score = {}\n"
268-
"isolated node number = {}".
269-
format(branch_res["gold_len"], branch_res["test_len"], branch_res["true_pos_num"],
270-
branch_res["false_neg_num"], branch_res["false_pos_num"], branch_res["mean_dis"],
271-
branch_res["tot_dis"], branch_res["pt_cost"], branch_res["iso_node_num"]))
272-
print("----------------End-----------------")
273-
if output_dest and os.path.exists(output_dest):
274-
swc_save(test_swc_treeroot, os.path.join(output_dest,
275-
"branch_metric",
276-
"{}{}".format(gold_file_name[:-4], "_test.swc")))
277-
swc_save(gold_swc_treeroot, os.path.join(output_dest,
278-
"branch_metric",
279-
"{}{}".format(gold_file_name[:-4], "_gold.swc")))
280-
if metric == "link_metric" or metric == "LM":
281-
link_res = link_metric(test_swc_tree=test_swc_treeroot,
282-
gold_swc_tree=gold_swc_treeroot,
283-
config=config)
284-
print("---------------Result---------------")
285-
print("edge_loss = {}\n"
286-
"tree_dis_loss = {}\n".format(link_res["edge_loss"], link_res["tree_dis_loss"]))
287-
print("---------------End---------------")
238+
if output_dir:
239+
read_json.save_json(data=result,
240+
json_file_path=os.path.join(output_dir, file_name))
241+
242+
243+
# command program
244+
def run():
245+
abs_dir = os.path.abspath("")
246+
init(abs_dir)
247+
248+
try:
249+
args = read_parameters()
250+
except:
251+
raise Exception("[Error: ] Error in reading parameters")
252+
gold_swc_tree, test_swc_trees, metric, output_dir, detail_dir, config, is_debug = set_configs(abs_dir, args)
253+
254+
for test_swc_tree in test_swc_trees:
255+
excute_metric(metric=metric, gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree,
256+
config=config, detail_dir=detail_dir, output_dir=output_dir)
257+
if metric in ["length_metric", "diadem_metric"]:
258+
excute_metric(metric=metric, gold_swc_tree=test_swc_tree, test_swc_tree=gold_swc_tree,
259+
config=config, detail_dir=detail_dir, output_dir=output_dir, file_name_extra="reverse")
288260

289261

290262
if __name__ == "__main__":
@@ -306,4 +278,4 @@ def run(DEBUG=True):
306278

307279
# pyneval --gold .\\data\test_data\geo_metric_data\gold_34_23_10.swc --test .\data\test_data\geo_metric_data\test_34_23_10.swc --metric branch_metric
308280

309-
# pyneval --gold .\\data\test_data\geo_metric_data\gold_34_23_10.swc --test .\data\test_data\geo_metric_data\test_34_23_10.swc --metric branch_metric
281+
# pyneval --gold ./data/test_data/geo_metric_data/gold_fake_data1.swc --test ./data/test_data/geo_test/ --metric branch_metric --detail ./output

pyneval/io/read_swc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@
44

55
# if path is a fold
66
def read_swc_trees(swc_file_paths, tree_name_dict=None):
7+
"""
8+
Read a swc tree or recursively read all the swc trees in a fold
9+
Args:
10+
swc_file_paths(string): path to read swc
11+
tree_name_dict(dict): a map for swc tree and its file name
12+
key(SwcTree): SwcTree object
13+
value(string): name of the swc tree
14+
Output:
15+
swc_tree_list(list): a list shaped 1*n, containing all the swc tree in the path
16+
"""
717
swc_tree_list = []
818
if os.path.isfile(swc_file_paths):
919
if not (swc_file_paths[-4:] == ".swc" or swc_file_paths[-4:] == ".SWC"):

0 commit comments

Comments
 (0)