@@ -195,6 +195,9 @@ def __init__(
195195 self .encode_path_length = encode_path_length
196196 self .edge_position_vars = {}
197197
198+ self .edges_set_to_zero = {}
199+ self .edges_set_to_one = {}
200+
198201 self .solver_options = solver_options
199202 if self .solver_options is None :
200203 self .solver_options = {}
@@ -288,6 +291,8 @@ def create_solver_and_paths(self):
288291
289292 self ._encode_paths ()
290293
294+ self ._apply_safety_optimizations_fix_zero_edges ()
295+
291296 def _encode_paths (self ):
292297
293298 # Encodes the paths in the graph by creating variables for edges and subpaths.
@@ -447,49 +452,108 @@ def _encode_paths(self):
447452 name = f"path_length_constr_i={ i } "
448453 )
449454
450- ########################################
451- # #
452- # Fixing variables based on safe lists #
453- # #
454- ########################################
455+ def _apply_safety_optimizations (self ):
455456
456457 if self .safe_lists is not None :
457- paths_to_fix = self ._get_paths_to_fix_from_safe_lists ()
458-
459- if not self .optimize_with_safety_as_subpath_constraints :
460- # iterating over safe lists
461- for i in range (min (len (paths_to_fix ), self .k )):
462- # print("Fixing variables for safe list #", i)
463- # iterate over the edges in the safe list to fix variables to 1
464- for u , v in paths_to_fix [i ]:
465- self .solver .add_constraint (
466- self .edge_vars [(u , v , i )] == 1 ,
467- name = f"safe_list_u={ u } _v={ v } _i={ i } " ,
468- )
458+ self .paths_to_fix = self ._get_paths_to_fix_from_safe_lists ()
459+
460+ if not self .optimize_with_safety_as_subpath_constraints :
461+ # iterating over safe lists
462+ for i in range (min (len (self .paths_to_fix ), self .k )):
463+ # print("Fixing variables for safe list #", i)
464+ # iterate over the edges in the safe list to fix variables to 1
465+ for u , v in self .paths_to_fix [i ]:
466+ self .solver .add_constraint (
467+ self .edge_vars [(u , v , i )] == 1 ,
468+ name = f"safe_list_u={ u } _v={ v } _i={ i } " ,
469+ )
470+ self .edges_set_to_one [(u , v , i )] = True
471+
472+ self ._apply_safety_optimizations_fix_zero_edges ()
473+
474+ def _apply_safety_optimizations_fix_zero_edges (self ):
475+ """
476+ Prune layer-edge variables to zero using safe-walk reachability while
477+ preserving edges that can be part of the walk or its connectors.
478+
479+ For each walk i in `walks_to_fix` we build a protection set of edges that
480+ must not be fixed to 0 for layer i:
481+ 1) Protect all edges that appear in the walk itself.
482+ 2) Whole-walk reachability: let first_node be the first node of the walk
483+ and last_node the last node. Protect any edge (u,v) such that
484+ - u is reachable (forward) from last_node, OR
485+ - v can reach (backward) first_node.
486+ 3) Gap-bridging between consecutive edges: for every pair of consecutive
487+ edges whose endpoints do not match (a gap), let
488+ - current_last = end node of the first edge, and
489+ - current_start = start node of the next edge.
490+ Protect any edge (u,v) such that
491+ - u is reachable (forward) from current_last, AND
492+ - v can reach (backward) current_start.
493+
494+ All remaining edges (u,v) not in the protection set are fixed to 0 in
495+ layer i.
496+
497+ Notes:
498+ - Requires `self.paths_to_fix` already computed and `self.edge_vars` created.
499+ """
500+ if not hasattr (self , "paths_to_fix" ) or self .paths_to_fix is None :
501+ return
502+
503+ fixed_zero_count = 0
504+ # Ensure we don't go beyond k layers
505+ for i in range (min (len (self .paths_to_fix ), self .k )):
506+ path = self .paths_to_fix [i ]
507+ if not path or len (path ) == 0 :
508+ continue
509+
510+ # Build the set of edges that should NOT be fixed to 0 for this layer i
511+ # Start by protecting all edges in the path itself
512+ protected_edges = set ((u , v ) for (u , v ) in path if self .G .has_edge (u , v ))
513+
514+ # Also protect edges that are reachable from the last node of the path
515+ # or that can reach the first node of the path
516+ first_node = path [0 ][0 ]
517+ last_node = path [- 1 ][1 ]
518+ for (u , v ) in self .G .edges :
519+ if (u in self .G .reachable_nodes_from [last_node ]) or (v in self .G .nodes_reaching (first_node )):
520+ protected_edges .add ((u , v ))
521+
522+ # Collect pairs of non-contiguous consecutive edges (gaps)
523+ gap_pairs = []
524+ for idx in range (len (path ) - 1 ):
525+ end_prev = path [idx ][1 ]
526+ start_next = path [idx + 1 ][0 ]
527+ # We consider all consecutive edges as gap pairs, because there could be a cycle
528+ # formed between them (this is not the case in DAGs)
529+ if end_prev != start_next :
530+ gap_pairs .append ((end_prev , start_next ))
531+
532+ # For each gap, add edges that can lie on some path bridging the gap
533+ for (current_last , current_start ) in gap_pairs :
534+ for (u , v ) in self .G .edges :
535+ if (u in self .G .nodes_reachable (current_last )) and (v in self .G .nodes_reaching (current_start )):
536+ # if (u in reachable_from_last) and (v in can_reach_start):
537+ protected_edges .add ((u , v ))
538+
539+ # Now fix every other edge to 0 for this layer i
540+ for (u , v ) in self .G .edges :
541+ if (u , v ) in protected_edges :
542+ continue
543+ # Queue zero-fix for batch bounds update
544+ # self.solver.queue_fix_variable(self.edge_vars[(u, v, i)], int(0))
545+ self .solver .add_constraint (
546+ self .edge_vars [(u , v , i )] == 0 ,
547+ name = f"i={ i } _u={ u } _v={ v } _fix0" ,
548+ )
549+ self .edges_set_to_zero [(u , v , i )] = True
550+ fixed_zero_count += 1
551+
552+ if fixed_zero_count :
553+ # Accumulate into solve statistics
554+ self .solve_statistics ["edge_variables=0" ] = self .solve_statistics .get ("edge_variables=0" , 0 ) + fixed_zero_count
555+ utils .logger .debug (f"{ __name__ } : Fixed { fixed_zero_count } edge variables to 0 via reachability pruning." )
469556
470- if self .optimize_with_safe_zero_edges :
471- # get the endpoints of the longest safe path in the sequence
472- first_node , last_node = (
473- safetypathcovers .get_endpoints_of_longest_safe_path_in (paths_to_fix [i ])
474- )
475- # get the reachable nodes from the last node
476- reachable_nodes = self .G .reachable_nodes_from [last_node ]
477- # get the backwards reachable nodes from the first node
478- reachable_nodes_reverse = self .G .reachable_nodes_rev_from [first_node ]
479- # get the edges in the path
480- path_edges = set ((u , v ) for (u , v ) in paths_to_fix [i ])
481-
482- for u , v in self .G .base_graph .edges ():
483- if (
484- (u , v ) not in path_edges
485- and u not in reachable_nodes
486- and v not in reachable_nodes_reverse
487- ):
488- # print(f"Adding zero constraint for edge ({u}, {v}) in path {i}")
489- self .solver .add_constraint (
490- self .edge_vars [(u , v , i )] == 0 ,
491- name = f"safe_list_zero_edge_u={ u } _v={ v } _i={ i } " ,
492- )
493557
494558
495559 def _get_paths_to_fix_from_safe_lists (self ) -> list :
0 commit comments