Commit ec7f857

anna-grim and anna-grim authored
feat: equivalent labels defined wrt new connections: (#55)
Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 326ce79 commit ec7f857

File tree

5 files changed: +299 -62 lines changed

src/segmentation_skeleton_metrics/graph_utils.py

Lines changed: 18 additions & 0 deletions
@@ -233,3 +233,21 @@ def sample_leaf(graph):
     """
     leafs = [i for i in graph.nodes if graph.degree[i] == 1]
     return sample(leafs, 1)[0]
+
+
+def sample_node(graph):
+    """
+    Samples a node from "graph".
+
+    Parameters
+    ----------
+    graph : networkx.Graph
+        Graph to be sampled from.
+
+    Returns
+    -------
+    int
+        Node.
+
+    """
+    return sample(list(graph.nodes), 1)[0]
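For context, the new sample_node helper mirrors sample_leaf but draws from all nodes rather than only degree-1 nodes. Below is a minimal usage sketch, not part of the commit, assuming that sample is imported from Python's random module as elsewhere in graph_utils.py:

# Minimal sketch (not from this commit): exercising sample_leaf and
# sample_node on a toy graph. Assumes "sample" is random.sample.
import networkx as nx
from random import sample


def sample_leaf(graph):
    # Same logic as graph_utils.sample_leaf: pick a random degree-1 node.
    leafs = [i for i in graph.nodes if graph.degree[i] == 1]
    return sample(leafs, 1)[0]


def sample_node(graph):
    # Same logic as the new graph_utils.sample_node: pick any random node.
    return sample(list(graph.nodes), 1)[0]


if __name__ == "__main__":
    graph = nx.path_graph(10)   # nodes 0..9; leaves are 0 and 9
    print(sample_leaf(graph))   # prints 0 or 9
    print(sample_node(graph))   # prints any node in 0..9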

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 99 additions & 55 deletions
@@ -20,8 +20,9 @@
 from segmentation_skeleton_metrics import split_detection, swc_utils, utils
 from segmentation_skeleton_metrics.swc_utils import save, to_graph
 
-INTERSECTION_THRESHOLD = 10
-MERGE_DIST_THRESHOLD = 40
+CLOSE_DIST_THRESHOLD = 5
+INTERSECTION_THRESHOLD = 16
+MERGE_DIST_THRESHOLD = 30
 
 
 class SkeletonMetric:
@@ -47,7 +48,7 @@ def __init__(
         anisotropy=[1.0, 1.0, 1.0],
         black_holes_xyz_id=None,
         black_hole_radius=24,
-        equivalent_ids=None,
+        connections_path=None,
         ignore_boundary_mistakes=False,
         output_dir=None,
         valid_size_threshold=25,
@@ -75,8 +76,9 @@ def __init__(
             ...
         black_hole_radius : float, optional
             ...
-        equivalent_ids : ...
-            ...
+        connections_path : str, optional
+            Path to a txt file containing pairs of swc ids from the prediction
+            that were predicted to be connected.
         ignore_boundary_mistakes : bool, optional
             Indication of whether to ignore mistakes near boundary of bounding
             box. The default is False.
@@ -105,21 +107,35 @@ def __init__(
         self.init_black_holes(black_holes_xyz_id)
         self.black_hole_radius = black_hole_radius
 
-        # Build Graphs
+        # Labels
         self.label_mask = pred_labels
         self.valid_labels = swc_utils.parse(
-            pred_swc_paths, valid_size_threshold, anisotropy=anisotropy
+            pred_swc_paths, valid_size_threshold, anisotropy
         )
+        self.init_equiv_labels(connections_path)
 
+        # Build Graphs
         self.target_graphs = self.init_graphs(target_swc_paths, anisotropy)
-        self.labeled_target_graphs = self.init_labeled_target_graphs()
+        self.init_labeled_target_graphs()
 
         # Build kdtree
         self.init_xyz_to_id_node()
         self.init_kdtree()
-        self.rm_spurious_intersections()
 
     # -- Initialize and Label Graphs --
+    def init_equiv_labels(self, path):
+        if path:
+            self.equiv_labels_map = utils.equiv_class_mappings(
+                path, self.valid_labels
+            )
+            valid_labels = dict()
+            for label, values in self.valid_labels.items():
+                equiv_label = self.equiv_labels_map[label]
+                valid_labels[equiv_label] = values
+            self.valid_labels = valid_labels
+        else:
+            self.equiv_labels_map = None
+
     def init_graphs(self, paths, anisotropy):
         """
         Initializes "self.target_graphs" by iterating over "paths" which
@@ -162,17 +178,16 @@ def init_labeled_target_graphs(self):
         """
         print("Labelling Target Graphs...")
         t0 = time()
-        labeled_target_graphs = dict()
+        self.labeled_target_graphs = dict()
         self.id_to_label_nodes = dict()  # {target_id: {label: nodes}}
         for cnt, (target_id, graph) in enumerate(self.target_graphs.items()):
             utils.progress_bar(cnt + 1, len(self.target_graphs))
             labeled_target_graph, id_to_label_nodes = self.label_graph(graph)
-            labeled_target_graphs[target_id] = labeled_target_graph
+            self.labeled_target_graphs[target_id] = labeled_target_graph
             self.id_to_label_nodes[target_id] = id_to_label_nodes
 
         t, unit = utils.time_writer(time() - t0)
         print(f"\nRuntime: {round(t, 2)} {unit}\n")
-        return labeled_target_graphs
 
     def label_graph(self, target_graph):
         """
@@ -229,15 +244,17 @@ def get_label(self, img_coord, return_node=False):
         if self.in_black_hole(img_coord):
             label = -1
         else:
-            label = self.__read_label(img_coord)
+            label = self.read_label(img_coord)
 
-        # Validate label
+        # Adjust label
+        label = self.equivalent_label(label)
+        label = self.validate(label)
         if return_node:
-            return return_node, self.is_valid(label)
+            return return_node, label
         else:
-            return self.is_valid(label)
+            return label
 
-    def __read_label(self, coord):
+    def read_label(self, coord):
         """
         Gets label at image coordinates "xyz".
 
@@ -252,12 +269,23 @@ def __read_label(self, coord):
             Label at image coordinates "xyz".
 
         """
+        # Read image label
         if type(self.label_mask) == ts.TensorStore:
             return int(self.label_mask[coord].read().result())
         else:
             return self.label_mask[coord]
 
-    def is_valid(self, label):
+    def equivalent_label(self, label):
+        # Equivalent label
+        if self.equiv_labels_map:
+            if label in self.equiv_labels_map.keys():
+                return self.equiv_labels_map[label]
+            else:
+                return 0
+        else:
+            return label
+
+    def validate(self, label):
         """
         Validates label by checking whether it is contained in
         "self.valid_labels".
@@ -290,35 +318,6 @@ def init_xyz_to_id_node(self):
             else:
                 self.xyz_to_id_node[xyz] = {target_id: i}
 
-    def rm_spurious_intersections(self):
-        for label in [label for label in self.get_all_labels() if label > 0]:
-            # Compute label intersect target_graphs
-            hit_target_ids = dict()
-            multi_hits = set()
-            for xyz in self.get_pred_coords(label):
-                hat_xyz, d = self.get_projection(xyz)
-                if d < 5:
-                    hits = list(self.xyz_to_id_node[hat_xyz].keys())
-                    if len(hits) > 1:
-                        multi_hits.add(hat_xyz)
-                    else:
-                        hat_i = self.xyz_to_id_node[hat_xyz][hits[0]]
-                        hit_target_ids = utils.append_dict_value(
-                            hit_target_ids, hits[0], hat_i
-                        )
-            hit_target_ids = utils.resolve(
-                multi_hits, hit_target_ids, self.xyz_to_id_node
-            )
-
-            # Remove spurious intersections
-            for target_id in self.target_graphs.keys():
-                if target_id in hit_target_ids.keys():
-                    n_hits = len(hit_target_ids[target_id])
-                    if n_hits < INTERSECTION_THRESHOLD:
-                        self.zero_nodes(target_id, label)
-                elif label in self.id_to_label_nodes[target_id]:
-                    self.zero_nodes(target_id, label)
-
     def get_pred_coords(self, label):
         if label in self.valid_labels.keys():
             return self.valid_labels[label]
@@ -347,14 +346,13 @@ def init_black_holes(self, black_holes):
             self.black_holes = None
             self.black_hole_labels = set()
 
-    def in_black_hole(self, xyz, print_nn=False):
-        # Check whether black_holes exists
-        if self.black_holes is None:
-            return False
-        else:
+    def in_black_hole(self, xyz):
+        if self.black_holes:
             radius = self.black_hole_radius
             pts = self.black_holes.query_ball_point(xyz, radius)
             return True if len(pts) > 0 else False
+        else:
+            return False
 
     # -- Evaluation --
     def compute_metrics(self):
@@ -487,10 +485,11 @@ def detect_merges(self):
         None
 
         """
-        # Initilize counts
+        # Initializations
         self.merge_cnts = self.init_counter()
         self.merged_cnts = self.init_counter()
        self.merged_percents = self.init_counter()
+        # self.rm_spurious_intersections()
 
         # Run detection
         t0 = time()
@@ -512,7 +511,9 @@ def detect_merges(self):
                 if merge not in detected_merges:
                     detected_merges.add(merge)
                     if self.save:
-                        site, d = self.locate(target_id_1, target_id_2, label)
+                        site, d = self.localize_site(
+                            target_id_1, target_id_2, label
+                        )
                         if d < MERGE_DIST_THRESHOLD:
                             self.save_swc(site[0], site[1], "merge")
 
@@ -528,7 +529,36 @@ def detect_merges(self):
         t, unit = utils.time_writer(time() - t0)
         print(f"\nRuntime: {round(t, 2)} {unit}\n")
 
-    def locate(self, target_id_1, target_id_2, label):
+    def rm_spurious_intersections(self):
+        for label in [label for label in self.get_all_labels() if label > 0]:
+            # Compute label intersect target_graphs
+            hit_target_ids = dict()
+            multi_hits = set()
+            for xyz in self.get_pred_coords(label):
+                hat_xyz, d = self.get_projection(xyz)
+                if d < CLOSE_DIST_THRESHOLD:
+                    hits = list(self.xyz_to_id_node[hat_xyz].keys())
+                    if len(hits) > 1:
+                        multi_hits.add(hat_xyz)
+                    else:
+                        hat_i = self.xyz_to_id_node[hat_xyz][hits[0]]
+                        hit_target_ids = utils.append_dict_value(
+                            hit_target_ids, hits[0], hat_i
+                        )
+            hit_target_ids = utils.resolve(
+                multi_hits, hit_target_ids, self.xyz_to_id_node
+            )
+
+            # Remove spurious intersections
+            for target_id in self.target_graphs.keys():
+                if target_id in hit_target_ids.keys():
+                    n_hits = len(hit_target_ids[target_id])
+                    if n_hits < INTERSECTION_THRESHOLD:
+                        self.zero_nodes(target_id, label)
+                elif label in self.id_to_label_nodes[target_id]:
+                    self.zero_nodes(target_id, label)
+
+    def localize_site(self, target_id_1, target_id_2, label):
         # Get merged nodes
         merged_1 = self.id_to_label_nodes[target_id_1][label]
         merged_2 = self.id_to_label_nodes[target_id_2][label]
@@ -548,6 +578,20 @@ def locate(self, target_id_1, target_id_2, label):
         return xyz_pair, min_dist
 
     def near_bdd(self, xyz):
+        """
+        Determines whether "xyz" is near the boundary of the image.
+
+        Parameters
+        ----------
+        xyz : numpy.ndarray
+            xyz coordinate to be checked.
+
+        Returns
+        -------
+        near_bdd_bool : bool
+            Indication of whether "xyz" is near the boundary of the image.
+
+        """
         near_bdd_bool = False
         if self.ignore_boundary_mistakes:
             mask_shape = self.label_mask.shape
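The helper utils.equiv_class_mappings is not shown in this commit, so the sketch below is a hypothetical illustration of how pairs of connected swc ids could be collapsed into equivalence classes with union-find; the one-pair-per-line, comma-separated file format is an assumption, not something the diff confirms. In init_equiv_labels above, such a mapping re-keys self.valid_labels by class representative, and get_label then routes every raw image label through equivalent_label before validate.

# Hypothetical sketch, not the package's implementation: collapse swc ids
# that were predicted to be connected into equivalence classes (union-find).
# Assumed file format: one comma-separated pair per line, e.g. "123,456".


def load_connections(path):
    # Parse assumed "a,b" pairs, skipping blank lines.
    pairs = []
    with open(path, "r") as f:
        for line in f:
            if line.strip():
                a, b = line.strip().split(",")
                pairs.append((int(a), int(b)))
    return pairs


def equiv_class_mappings(path, valid_labels):
    # Union-find over every label seen in valid_labels or the connections.
    parent = {label: label for label in valid_labels}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path compression
            x = parent[x]
        return x

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    for a, b in load_connections(path):
        union(a, b)

    # Map every label to its class representative.
    return {label: find(label) for label in parent}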

src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 20 additions & 2 deletions
@@ -16,6 +16,23 @@
 
 
 def run(target_graph, labeled_graph):
+    """
+    Detects splits in a predicted segmentation.
+
+    Parameters
+    ----------
+    target_graph : networkx.Graph
+        Graph built from a ground truth swc file.
+    labeled_graph : networkx.Graph
+        Labeled graph built from a ground truth swc file, where each node has
+        an attribute called 'label'.
+
+    Returns
+    -------
+    labeled_graph : networkx.Graph
+        Labeled graph with omit and split edges removed.
+
+    """
     r = gutils.sample_leaf(target_graph)
     dfs_edges = list(nx.dfs_edges(target_graph, source=r))
     while len(dfs_edges) > 0:
@@ -43,9 +60,10 @@ def is_zero_misalignment(target_graph, labeled_graph, dfs_edges, nb, root):
     Parameters
     ----------
     target_graph : networkx.Graph
-        ...
+        Graph built from a ground truth swc file.
     labeled_graph : networkx.Graph
-        ...
+        Labeled graph built from a ground truth swc file, where each node has
+        an attribute called 'label'.
     dfs_edges : list[tuple]
         List of edges to be processed for split detection.
     nb : int
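For orientation, run() walks the ground-truth graph depth-first from a randomly sampled leaf and inspects labels along each traversed edge. The toy example below only illustrates that traversal pattern; it is not the package's split or misalignment logic.

# Toy illustration of the traversal in split_detection.run (not the actual
# split-detection logic): DFS from a random leaf of a ground-truth graph,
# flagging edges whose endpoints carry different predicted labels.
import networkx as nx
from random import sample


def sample_leaf(graph):
    leafs = [i for i in graph.nodes if graph.degree[i] == 1]
    return sample(leafs, 1)[0]


graph = nx.path_graph(6)
labels = {0: 7, 1: 7, 2: 7, 3: 9, 4: 9, 5: 9}  # pretend predicted labels
nx.set_node_attributes(graph, labels, "label")

root = sample_leaf(graph)
for i, j in nx.dfs_edges(graph, source=root):
    if graph.nodes[i]["label"] != graph.nodes[j]["label"]:
        print(f"potential split between nodes {i} and {j}")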
