2020from segmentation_skeleton_metrics import split_detection , swc_utils , utils
2121from segmentation_skeleton_metrics .swc_utils import save , to_graph
2222
23- INTERSECTION_THRESHOLD = 10
24- MERGE_DIST_THRESHOLD = 40
23+ CLOSE_DIST_THRESHOLD = 5
24+ INTERSECTION_THRESHOLD = 16
25+ MERGE_DIST_THRESHOLD = 30
2526
2627
2728class SkeletonMetric :
@@ -47,7 +48,7 @@ def __init__(
4748 anisotropy = [1.0 , 1.0 , 1.0 ],
4849 black_holes_xyz_id = None ,
4950 black_hole_radius = 24 ,
50- equivalent_ids = None ,
51+ connections_path = None ,
5152 ignore_boundary_mistakes = False ,
5253 output_dir = None ,
5354 valid_size_threshold = 25 ,
@@ -75,8 +76,9 @@ def __init__(
7576 ...
7677 black_hole_radius : float, optional
7778 ...
78- equivalent_ids : ...
79- ...
79+ connections_path : list[tuple]
80+ Path to a txt file containing pairs of swc ids from the prediction
81+ that were predicted to be connected.
8082 ignore_boundary_mistakes : bool, optional
8183 Indication of whether to ignore mistakes near boundary of bounding
8284 box. The default is False.
@@ -105,21 +107,35 @@ def __init__(
105107 self .init_black_holes (black_holes_xyz_id )
106108 self .black_hole_radius = black_hole_radius
107109
108- # Build Graphs
110+ # Labels
109111 self .label_mask = pred_labels
110112 self .valid_labels = swc_utils .parse (
111- pred_swc_paths , valid_size_threshold , anisotropy = anisotropy
113+ pred_swc_paths , valid_size_threshold , anisotropy
112114 )
115+ self .init_equiv_labels (connections_path )
113116
117+ # Build Graphs
114118 self .target_graphs = self .init_graphs (target_swc_paths , anisotropy )
115- self .labeled_target_graphs = self . init_labeled_target_graphs ()
119+ self .init_labeled_target_graphs ()
116120
117121 # Build kdtree
118122 self .init_xyz_to_id_node ()
119123 self .init_kdtree ()
120- self .rm_spurious_intersections ()
121124
122125 # -- Initialize and Label Graphs --
126+ def init_equiv_labels (self , path ):
127+ if path :
128+ self .equiv_labels_map = utils .equiv_class_mappings (
129+ path , self .valid_labels
130+ )
131+ valid_labels = dict ()
132+ for label , values in self .valid_labels .items ():
133+ equiv_label = self .equiv_labels_map [label ]
134+ valid_labels [equiv_label ] = values
135+ self .valid_labels = valid_labels
136+ else :
137+ self .equiv_labels_map = None
138+
123139 def init_graphs (self , paths , anisotropy ):
124140 """
125141 Initializes "self.target_graphs" by iterating over "paths" which
@@ -162,17 +178,16 @@ def init_labeled_target_graphs(self):
162178 """
163179 print ("Labelling Target Graphs..." )
164180 t0 = time ()
165- labeled_target_graphs = dict ()
181+ self . labeled_target_graphs = dict ()
166182 self .id_to_label_nodes = dict () # {target_id: {label: nodes}}
167183 for cnt , (target_id , graph ) in enumerate (self .target_graphs .items ()):
168184 utils .progress_bar (cnt + 1 , len (self .target_graphs ))
169185 labeled_target_graph , id_to_label_nodes = self .label_graph (graph )
170- labeled_target_graphs [target_id ] = labeled_target_graph
186+ self . labeled_target_graphs [target_id ] = labeled_target_graph
171187 self .id_to_label_nodes [target_id ] = id_to_label_nodes
172188
173189 t , unit = utils .time_writer (time () - t0 )
174190 print (f"\n Runtime: { round (t , 2 )} { unit } \n " )
175- return labeled_target_graphs
176191
177192 def label_graph (self , target_graph ):
178193 """
@@ -229,15 +244,17 @@ def get_label(self, img_coord, return_node=False):
229244 if self .in_black_hole (img_coord ):
230245 label = - 1
231246 else :
232- label = self .__read_label (img_coord )
247+ label = self .read_label (img_coord )
233248
234- # Validate label
249+ # Adjust label
250+ label = self .equivalent_label (label )
251+ label = self .validate (label )
235252 if return_node :
236- return return_node , self . is_valid ( label )
253+ return return_node , label
237254 else :
238- return self . is_valid ( label )
255+ return label
239256
240- def __read_label (self , coord ):
257+ def read_label (self , coord ):
241258 """
242259 Gets label at image coordinates "xyz".
243260
@@ -252,12 +269,23 @@ def __read_label(self, coord):
252269 Label at image coordinates "xyz".
253270
254271 """
272+ # Read image label
255273 if type (self .label_mask ) == ts .TensorStore :
256274 return int (self .label_mask [coord ].read ().result ())
257275 else :
258276 return self .label_mask [coord ]
259277
260- def is_valid (self , label ):
278+ def equivalent_label (self , label ):
279+ # Equivalent label
280+ if self .equiv_labels_map :
281+ if label in self .equiv_labels_map .keys ():
282+ return self .equiv_labels_map [label ]
283+ else :
284+ return 0
285+ else :
286+ return label
287+
288+ def validate (self , label ):
261289 """
262290 Validates label by checking whether it is contained in
263291 "self.valid_labels".
@@ -290,35 +318,6 @@ def init_xyz_to_id_node(self):
290318 else :
291319 self .xyz_to_id_node [xyz ] = {target_id : i }
292320
293- def rm_spurious_intersections (self ):
294- for label in [label for label in self .get_all_labels () if label > 0 ]:
295- # Compute label intersect target_graphs
296- hit_target_ids = dict ()
297- multi_hits = set ()
298- for xyz in self .get_pred_coords (label ):
299- hat_xyz , d = self .get_projection (xyz )
300- if d < 5 :
301- hits = list (self .xyz_to_id_node [hat_xyz ].keys ())
302- if len (hits ) > 1 :
303- multi_hits .add (hat_xyz )
304- else :
305- hat_i = self .xyz_to_id_node [hat_xyz ][hits [0 ]]
306- hit_target_ids = utils .append_dict_value (
307- hit_target_ids , hits [0 ], hat_i
308- )
309- hit_target_ids = utils .resolve (
310- multi_hits , hit_target_ids , self .xyz_to_id_node
311- )
312-
313- # Remove spurious intersections
314- for target_id in self .target_graphs .keys ():
315- if target_id in hit_target_ids .keys ():
316- n_hits = len (hit_target_ids [target_id ])
317- if n_hits < INTERSECTION_THRESHOLD :
318- self .zero_nodes (target_id , label )
319- elif label in self .id_to_label_nodes [target_id ]:
320- self .zero_nodes (target_id , label )
321-
322321 def get_pred_coords (self , label ):
323322 if label in self .valid_labels .keys ():
324323 return self .valid_labels [label ]
@@ -347,14 +346,13 @@ def init_black_holes(self, black_holes):
347346 self .black_holes = None
348347 self .black_hole_labels = set ()
349348
350- def in_black_hole (self , xyz , print_nn = False ):
351- # Check whether black_holes exists
352- if self .black_holes is None :
353- return False
354- else :
349+ def in_black_hole (self , xyz ):
350+ if self .black_holes :
355351 radius = self .black_hole_radius
356352 pts = self .black_holes .query_ball_point (xyz , radius )
357353 return True if len (pts ) > 0 else False
354+ else :
355+ return False
358356
359357 # -- Evaluation --
360358 def compute_metrics (self ):
@@ -487,10 +485,11 @@ def detect_merges(self):
487485 None
488486
489487 """
490- # Initilize counts
488+ # Initilizations
491489 self .merge_cnts = self .init_counter ()
492490 self .merged_cnts = self .init_counter ()
493491 self .merged_percents = self .init_counter ()
492+ # self.rm_spurious_intersections()
494493
495494 # Run detection
496495 t0 = time ()
@@ -512,7 +511,9 @@ def detect_merges(self):
512511 if merge not in detected_merges :
513512 detected_merges .add (merge )
514513 if self .save :
515- site , d = self .locate (target_id_1 , target_id_2 , label )
514+ site , d = self .localize_site (
515+ target_id_1 , target_id_2 , label
516+ )
516517 if d < MERGE_DIST_THRESHOLD :
517518 self .save_swc (site [0 ], site [1 ], "merge" )
518519
@@ -528,7 +529,36 @@ def detect_merges(self):
528529 t , unit = utils .time_writer (time () - t0 )
529530 print (f"\n Runtime: { round (t , 2 )} { unit } \n " )
530531
531- def locate (self , target_id_1 , target_id_2 , label ):
532+ def rm_spurious_intersections (self ):
533+ for label in [label for label in self .get_all_labels () if label > 0 ]:
534+ # Compute label intersect target_graphs
535+ hit_target_ids = dict ()
536+ multi_hits = set ()
537+ for xyz in self .self .get_pred_coords (label ):
538+ hat_xyz , d = self .get_projection (xyz )
539+ if d < CLOSE_DIST_THRESHOLD :
540+ hits = list (self .xyz_to_id_node [hat_xyz ].keys ())
541+ if len (hits ) > 1 :
542+ multi_hits .add (hat_xyz )
543+ else :
544+ hat_i = self .xyz_to_id_node [hat_xyz ][hits [0 ]]
545+ hit_target_ids = utils .append_dict_value (
546+ hit_target_ids , hits [0 ], hat_i
547+ )
548+ hit_target_ids = utils .resolve (
549+ multi_hits , hit_target_ids , self .xyz_to_id_node
550+ )
551+
552+ # Remove spurious intersections
553+ for target_id in self .target_graphs .keys ():
554+ if target_id in hit_target_ids .keys ():
555+ n_hits = len (hit_target_ids [target_id ])
556+ if n_hits < INTERSECTION_THRESHOLD :
557+ self .zero_nodes (target_id , label )
558+ elif label in self .id_to_label_nodes [target_id ]:
559+ self .zero_nodes (target_id , label )
560+
561+ def localize_site (self , target_id_1 , target_id_2 , label ):
532562 # Get merged nodes
533563 merged_1 = self .id_to_label_nodes [target_id_1 ][label ]
534564 merged_2 = self .id_to_label_nodes [target_id_2 ][label ]
@@ -548,6 +578,20 @@ def locate(self, target_id_1, target_id_2, label):
548578 return xyz_pair , min_dist
549579
550580 def near_bdd (self , xyz ):
581+ """
582+ Determines whether "xyz" is near the boundary of the image.
583+
584+ Parameters
585+ ----------
586+ xyz : numpy.ndarray
587+ xyz coordinate to be checked
588+
589+ Returns
590+ -------
591+ near_bdd_bool : bool
592+ Indication of whether "xyz" is near the boundary of the image.
593+
594+ """
551595 near_bdd_bool = False
552596 if self .ignore_boundary_mistakes :
553597 mask_shape = self .label_mask .shape
0 commit comments