Skip to content

Commit a1faee0

Browse files
anna-grimanna-grim
andauthored
Merge detection (#47)
* refactor: merge detection * upd: change threshold for merge detection * lint: black --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent d96814a commit a1faee0

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def init_xyz_to_swc_node(self):
108108
if xyz in self.xyz_to_swc_node.keys():
109109
self.xyz_to_swc_node[xyz][swc_id] = i
110110
else:
111-
self.xyz_to_swc_node[xyz] = {swc_id: i}
111+
self.xyz_to_swc_node[xyz] = {swc_id: i}
112112

113113
def init_kdtree(self):
114114
xyz_list = []
@@ -437,8 +437,9 @@ def detect_merges(self):
437437
valid_1 = label in self.target_to_pred[target_id_1]
438438
valid_2 = label in self.target_to_pred[target_id_2]
439439
if valid_1 and valid_2:
440-
sites, d = self.localize(target_id_1, target_id_2, label)
441-
xyz = utils.get_midpoint(sites[0], sites[1])
440+
sites, d = self.localize(
441+
target_id_1, target_id_2, label
442+
)
442443
if d < 30 and self.write_to_swc:
443444
# Process merge
444445
self.save_swc(sites[0], sites[1], "merge")
@@ -480,7 +481,7 @@ def set_target_to_pred(self):
480481
for target_id, values in hit_target_ids.items():
481482
if len(values) > 16:
482483
self.target_to_pred[target_id].add(int(pred_id))
483-
484+
484485
def localize(self, swc_id_1, swc_id_2, label):
485486
# Get merged nodes
486487
merged_1 = self.label_to_node[swc_id_1][label]
@@ -615,7 +616,7 @@ def compile_results(self):
615616
def generate_full_results(self):
616617
"""
617618
Generates a report by creating a list of the results for each metric.
618-
Each item in this list corresponds to a graph in "self.labeled_target_graphs"
619+
Each item in this list corresponds to a graph in labeled_target_graphs
619620
and this list is ordered with respect to "swc_ids".
620621
621622
Parameters
@@ -628,7 +629,7 @@ def generate_full_results(self):
628629
Specifies the ordering of results for each value in "stats".
629630
stats : dict
630631
Dictionary where the keys are metrics and values are the result of
631-
computing that metric for each graph in "self.labeled_target_graphs".
632+
computing that metric for each graph in labeled_target_graphs.
632633
633634
"""
634635
swc_ids = list(self.labeled_target_graphs.keys())

src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
import networkx as nx
13+
1314
from segmentation_skeleton_metrics import graph_utils as gutils
1415
from segmentation_skeleton_metrics import utils
1516

src/segmentation_skeleton_metrics/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def append_dict_value(my_dict, key, value):
315315
my_dict[key] = [value]
316316
return my_dict
317317

318+
318319
def find_best(my_dict, keys):
319320
best_key = None
320321
best_vote_cnt = 0

0 commit comments

Comments
 (0)