diff --git a/examples/cycle_detection.py b/examples/cycle_detection.py
new file mode 100644
index 000000000000..a25dc055638d
--- /dev/null
+++ b/examples/cycle_detection.py
@@ -0,0 +1,192 @@
+"""Example: Max-Weight Simple Cycle Detection.
+
+===========================================
+
+This example demonstrates how to use the cycle detection utilities in PyTorch
+Geometric to find maximum weight cycles in directed graphs.
+
+Use Case: DAG with Reversed Edges
+----------------------------------
+A common scenario is starting with a Directed Acyclic Graph (DAG) with
+non-negative edge weights, then adding reversed edges with negative weights.
+This creates cycles, and we want to find the cycle with the highest total
+weight.
+"""
+
+import torch
+
+from torch_geometric.utils import find_all_cycles, find_max_weight_cycle
+
+# Example 1: Simple Triangle Cycle
+# =================================
+print("Example 1: Simple Triangle Cycle")
+print("-" * 50)
+
+# Create a simple cycle: 0 -> 1 -> 2 -> 0
+edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
+edge_weight = torch.tensor([1.0, 2.0, 3.0])
+
+cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+print(f"Edge index:\n{edge_index}")
+print(f"Edge weights: {edge_weight.tolist()}")
+print(f"Found cycle: {cycle.tolist()}")
+print(f"Cycle weight: {weight}")
+print()
+
+# Example 2: DAG with Reversed Edges (Original Use Case)
+# =======================================================
+print("Example 2: DAG with Reversed Edges")
+print("-" * 50)
+
+# Original DAG: 0 -> 1 -> 2 -> 3
+# Add reversed edges: 1 -> 0, 2 -> 1
+# This creates multiple cycles with different weights
+
+# Forward edges (positive weights)
+forward_edges = torch.tensor([[0, 1, 2], [1, 2, 3]])
+forward_weights = torch.tensor([5.0, 3.0, 2.0])
+
+# Reversed edges (negative weights)
+reverse_edges = torch.tensor([[1, 2], [0, 1]])
+reverse_weights = torch.tensor([-4.0, -1.0])
+
+# Combine into single graph
+edge_index = torch.cat([forward_edges, reverse_edges], dim=1)
+edge_weight = torch.cat([forward_weights, reverse_weights])
+
+print(f"Forward edges: {forward_edges.t().tolist()}")
+print(f"Forward weights: {forward_weights.tolist()}")
+print(f"Reverse edges: {reverse_edges.t().tolist()}")
+print(f"Reverse weights: {reverse_weights.tolist()}")
+print()
+
+# Find the maximum weight cycle
+cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+if cycle is not None:
+    print(f"Maximum weight cycle: {cycle.tolist()}")
+    print(f"Cycle weight: {weight}")
+
+    # Reconstruct the cycle path with edge weights
+    print("\nCycle path:")
+    for i in range(len(cycle)):
+        src = cycle[i].item()
+        dst = cycle[(i + 1) % len(cycle)].item()
+
+        # Find edge weight
+        edge_mask = (edge_index[0] == src) & (edge_index[1] == dst)
+        edge_w = edge_weight[edge_mask].item()
+
+        print(f" {src} -> {dst}: weight = {edge_w}")
+else:
+    print("No cycle found")
+print()
+
+# Example 3: Finding All Cycles
+# ==============================
+print("Example 3: Finding All Cycles")
+print("-" * 50)
+
+# Create a graph with multiple cycles
+edge_index = torch.tensor([
+    [0, 1, 2, 1, 3, 4],  # Sources
+    [1, 2, 0, 0, 4, 3],  # Targets
+])
+edge_weight = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+
+print(f"Edge index:\n{edge_index}")
+print(f"Edge weights: {edge_weight.tolist()}")
+print()
+
+# Find all cycles
+cycles, weights = find_all_cycles(edge_index, edge_weight)
+
+print(f"Found {len(cycles)} cycles:")
+for i, (cycle, weight) in enumerate(zip(cycles, weights)):
+    print(f" Cycle {i + 1}: {cycle.tolist()}, weight = {weight}")
+print()
+
+# Example 4: Filtering Cycles by Minimum Weight
+# ==============================================
+print("Example 4: Filtering Cycles by Minimum Weight")
+print("-" * 50)
+
+# Only find cycles with weight >= 5.0
+cycles, weights = find_all_cycles(edge_index, edge_weight, min_weight=5.0)
+
+print("Cycles with weight >= 5.0:")
+for i, (cycle, weight) in enumerate(zip(cycles, weights)):
+    print(f" Cycle {i + 1}: {cycle.tolist()}, weight = {weight}")
+print()
+
+# Example 5: Limiting Cycle Length
+# =================================
+print("Example 5: Limiting Cycle Length")
+print("-" * 50)
+
+# Create a longer cycle
+edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
+edge_weight = torch.ones(5)
+
+print("Graph has a 5-node cycle")
+print()
+
+# Search with different max lengths
+for max_len in [3, 5, 10]:
+    cycle, weight = find_max_weight_cycle(
+        edge_index,
+        edge_weight,
+        max_cycle_length=max_len,
+    )
+
+    if cycle is not None:
+        print(f"Max length {max_len}: Found cycle of length {len(cycle)}")
+    else:
+        print(f"Max length {max_len}: No cycle found")
+print()
+
+# Example 6: Performance Tuning with top_k_paths
+# ===============================================
+print("Example 6: Performance Tuning")
+print("-" * 50)
+
+# Create a dense graph
+num_nodes = 10
+edges = []
+weights = []
+
+# Create a grid-like structure with cycles
+for i in range(num_nodes - 1):
+    edges.append([i, i + 1])
+    weights.append(1.0)
+    if i > 0:
+        edges.append([i, i - 1])
+        weights.append(0.5)
+
+# Add some long-range connections
+edges.extend([[0, 5], [5, 9], [9, 0]])
+weights.extend([2.0, 2.0, 2.0])
+
+edge_index = torch.tensor(edges).t()
+edge_weight = torch.tensor(weights)
+
+print(f"Graph with {num_nodes} nodes and {len(edges)} edges")
+print()
+
+# Compare different top_k values
+for top_k in [10, 50, 100]:
+    cycle, weight = find_max_weight_cycle(
+        edge_index,
+        edge_weight,
+        top_k_paths=top_k,
+    )
+
+    if cycle is not None:
+        print(f"top_k={top_k}: Found cycle with weight {weight:.2f}")
+    else:
+        print(f"top_k={top_k}: No cycle found")
+
+print()
+print("=" * 50)
+print("Examples completed!")
diff --git a/test/utils/test_cycle.py b/test/utils/test_cycle.py
new file mode 100644
--- /dev/null
+++ b/test/utils/test_cycle.py
@@ -0,0 +1,230 @@
+import torch
+
+from torch_geometric.utils import find_all_cycles, find_max_weight_cycle
+
+
+def test_find_max_weight_cycle_simple():
+    # Simple triangle cycle: 0 -> 1 -> 2 -> 0
+    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
+    edge_weight = torch.tensor([1.0, 2.0, 3.0])
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    assert cycle is not None
+    assert cycle.tolist() == [0, 1, 2]
+    assert weight == 6.0
+
+
+def test_find_max_weight_cycle_no_cycle():
+    # DAG with no cycles
+    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
+    edge_weight = torch.tensor([1.0, 2.0, 3.0])
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    assert cycle is None
+    assert weight == float('-inf')
+
+
+def test_find_max_weight_cycle_multiple_cycles():
+    # Graph with two cycles:
+    # Cycle 1: 0 -> 1 -> 0 (weight: 1 + 2 = 3)
+    # Cycle 2: 2 -> 3 -> 2 (weight: 4 + 5 = 9)
+    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
+    edge_weight = torch.tensor([1.0, 2.0, 4.0, 5.0])
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    assert cycle is not None
+    assert weight == 9.0
+    # Should find the heavier cycle
+    assert 2 in cycle.tolist() and 3 in cycle.tolist()
+
+
+def test_find_max_weight_cycle_negative_weights():
+    # DAG with reversed edges (negative weights)
+    # Forward: 0 -> 1 (5), 1 -> 2 (3), 2 -> 3 (2)
+    # Backward: 1 -> 0 (-4), 2 -> 1 (-1)
+    # Cycle 0 -> 1 -> 0: 5 + (-4) = 1
+    # Cycle 1 -> 2 -> 1: 3 + (-1) = 2
+    edge_index = torch.tensor([[0, 1, 2, 1, 2], [1, 2, 3, 0, 1]])
+    edge_weight = torch.tensor([5.0, 3.0, 2.0, -4.0, -1.0])
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    assert cycle is not None
+    assert cycle.tolist() == [1, 2]
+    assert weight == 2.0
+
+
+def test_find_max_weight_cycle_self_loop():
+    # Graph with self-loop
+    edge_index = torch.tensor([[0, 0, 1], [0, 1, 0]])
+    edge_weight = torch.tensor([10.0, 1.0, 2.0])
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    # Should find the 0 -> 1 -> 0 cycle, not the self-loop
+    assert cycle is not None
+    assert len(cycle) > 1  # Not a self-loop
+
+
+def test_find_max_weight_cycle_no_weights():
+    # Test with default weights (all 1.0)
+    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
+
+    cycle, weight = find_max_weight_cycle(edge_index)
+
+    assert cycle is not None
+    assert cycle.tolist() == [0, 1, 2]
+    assert weight == 3.0
+
+
+def test_find_max_weight_cycle_max_length():
+    # Test max_cycle_length parameter
+    # Create a long cycle: 0 -> 1 -> 2 -> 3 -> 4 -> 0
+    edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
+    edge_weight = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
+
+    # With max_cycle_length=3, should not find the 5-node cycle
+    cycle, weight = find_max_weight_cycle(
+        edge_index,
+        edge_weight,
+        max_cycle_length=3,
+    )
+
+    # Should either find no cycle or a shorter one
+    if cycle is not None:
+        assert len(cycle) <= 3
+
+
+def test_find_max_weight_cycle_top_k():
+    # Test top_k_paths parameter with complex graph
+    edge_index = torch.tensor([
+        [0, 0, 1, 1, 2, 2, 3],
+        [1, 2, 2, 3, 3, 0, 0],
+    ])
+    edge_weight = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
+
+    # Should work with different top_k values
+    cycle1, weight1 = find_max_weight_cycle(
+        edge_index,
+        edge_weight,
+        top_k_paths=10,
+    )
+    cycle2, weight2 = find_max_weight_cycle(
+        edge_index,
+        edge_weight,
+        top_k_paths=100,
+    )
+
+    # Both should find a cycle
+    assert cycle1 is not None
+    assert cycle2 is not None
+
+
+def test_find_max_weight_cycle_multi_edges():
+    # Parallel edges 0 -> 1 with weights 1 and 4: only the heavier one counts
+    edge_index = torch.tensor([[0, 0, 1], [1, 1, 0]])
+    edge_weight = torch.tensor([1.0, 4.0, 2.0])
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    assert cycle is not None
+    assert cycle.tolist() == [0, 1]
+    assert weight == 6.0
+
+
+def test_find_all_cycles_simple():
+    # Graph with two simple cycles
+    edge_index = torch.tensor([[0, 1, 2, 1], [1, 2, 0, 0]])
+    edge_weight = torch.tensor([1.0, 1.0, 1.0, 2.0])
+
+    cycles, weights = find_all_cycles(edge_index, edge_weight)
+
+    # Each cycle is reported exactly once, rooted at its smallest node
+    assert [c.tolist() for c in cycles] == [[0, 1], [0, 1, 2]]
+    assert weights == [3.0, 3.0]
+
+
+def test_find_all_cycles_min_weight():
+    # Test min_weight filtering
+    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
+    edge_weight = torch.tensor([1.0, 2.0, 4.0, 5.0])
+
+    # Only cycles with weight >= 5 should be returned
+    cycles, weights = find_all_cycles(edge_index, edge_weight, min_weight=5.0)
+
+    assert [c.tolist() for c in cycles] == [[2, 3]]
+    assert weights == [9.0]
+
+
+def test_find_all_cycles_no_cycles():
+    # DAG with no cycles
+    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
+
+    cycles, weights = find_all_cycles(edge_index)
+
+    assert len(cycles) == 0
+    assert len(weights) == 0
+
+
+def test_find_all_cycles_max_length():
+    # Test max_cycle_length parameter
+    edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
+
+    cycles, weights = find_all_cycles(edge_index, max_cycle_length=3)
+
+    # All cycles should have length <= 3
+    assert all(len(c) <= 3 for c in cycles)
+
+
+def test_cycle_detection_large_graph():
+    # Test with a larger graph
+    num_nodes = 20
+    # Create a ring graph with some additional edges
+    edge_list = [[i, (i + 1) % num_nodes] for i in range(num_nodes)]
+    # Add some shortcuts
+    edge_list.extend([[0, 5], [5, 10], [10, 15], [15, 0]])
+
+    edge_index = torch.tensor(edge_list).t()
+    edge_weight = torch.ones(edge_index.size(1))
+
+    cycle, weight = find_max_weight_cycle(
+        edge_index,
+        edge_weight,
+        max_cycle_length=10,
+    )
+
+    # Should find some cycle
+    assert cycle is not None
+    assert weight > 0
+
+
+def test_cycle_detection_disconnected_graph():
+    # Test with disconnected components
+    # Component 1: 0 -> 1 -> 0
+    # Component 2: 2 -> 3 -> 2
+    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
+    edge_weight = torch.tensor([1.0, 1.0, 5.0, 5.0])
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    # Should find the heavier cycle in component 2
+    assert cycle is not None
+    assert weight == 10.0
+
+
+def test_cycle_detection_cuda():
+    if not torch.cuda.is_available():
+        return
+
+    # Test on CUDA device
+    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]).cuda()
+    edge_weight = torch.tensor([1.0, 2.0, 3.0]).cuda()
+
+    cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+
+    assert cycle is not None
+    assert cycle.device.type == 'cuda'
+    assert weight == 6.0
diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py
index 68f0c7ad00ca..b51b4cd037f3 100644
--- a/torch_geometric/utils/__init__.py
+++ b/torch_geometric/utils/__init__.py
@@ -58,6 +58,7 @@
 from .ppr import get_ppr
 from ._train_test_split_edges import train_test_split_edges
 from .influence import total_influence
+from .cycle import find_max_weight_cycle, find_all_cycles
 
 __all__ = [
     'scatter',
@@ -151,6 +152,8 @@
     'get_ppr',
     'train_test_split_edges',
     'total_influence',
+    'find_max_weight_cycle',
+    'find_all_cycles',
 ]
 
 # `structured_negative_sampling_feasible` is a long name and thus destroys the
diff --git a/torch_geometric/utils/cycle.py b/torch_geometric/utils/cycle.py
new file mode 100644
--- /dev/null
+++ b/torch_geometric/utils/cycle.py
@@ -0,0 +1,275 @@
+"""Cycle detection utilities for directed graphs."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+
+from torch_geometric.utils import coalesce
+from torch_geometric.utils.num_nodes import maybe_num_nodes
+
+
+def _build_adjacency(
+    edge_index: Tensor,
+    edge_weight: Optional[Tensor],
+    num_nodes: Optional[int],
+) -> List[List[Tuple[int, float]]]:
+    r"""Converts a weighted edge list into per-node adjacency lists.
+
+    Parallel edges are merged by keeping the heaviest one, since a simple
+    cycle can use at most one edge between any ordered pair of nodes.
+    """
+    num_nodes = maybe_num_nodes(edge_index, num_nodes)
+
+    if edge_weight is None:
+        edge_weight = torch.ones(
+            edge_index.size(1),
+            dtype=torch.float,
+            device=edge_index.device,
+        )
+
+    # Pass `num_nodes` and `reduce` by keyword: the fourth positional
+    # parameter of `coalesce` is `reduce`, not a second node count.
+    # `coalesce` also sorts the edges, so `sort_edge_index` is not needed.
+    edge_index, edge_weight = coalesce(
+        edge_index,
+        edge_weight,
+        num_nodes=num_nodes,
+        reduce='max',
+    )
+
+    # Convert to Python objects once so that the search loops do not scan
+    # the edge tensor or call `.item()` for every path extension.
+    adj: List[List[Tuple[int, float]]] = [[] for _ in range(num_nodes)]
+    row, col = edge_index.tolist()
+    for src, dst, weight in zip(row, col, edge_weight.tolist()):
+        adj[src].append((dst, weight))
+
+    return adj
+
+
+def find_max_weight_cycle(
+    edge_index: Tensor,
+    edge_weight: Optional[Tensor] = None,
+    num_nodes: Optional[int] = None,
+    max_cycle_length: Optional[int] = None,
+    top_k_paths: int = 100,
+) -> Tuple[Optional[Tensor], float]:
+    r"""Finds a maximum weight simple cycle in a directed graph using a
+    beam search from every node.
+
+    Starting from each node in turn, simple paths are extended one edge at a
+    time, and every path that can be closed back to its starting node is
+    compared against the best cycle found so far. After each extension step,
+    only the :obj:`top_k_paths` heaviest open paths are kept, so the result
+    is only guaranteed to be optimal if no prefix of the best cycle is
+    pruned.
+
+    Self-loops are ignored, and parallel edges are merged by keeping the
+    heaviest one.
+
+    Args:
+        edge_index (torch.Tensor): The edge indices in COO format with shape
+            :obj:`[2, num_edges]`.
+        edge_weight (torch.Tensor, optional): Edge weights. If :obj:`None`,
+            all edges are assumed to have weight :obj:`1.0`.
+            (default: :obj:`None`)
+        num_nodes (int, optional): The number of nodes in the graph. If
+            :obj:`None`, will be inferred from :obj:`edge_index`.
+            (default: :obj:`None`)
+        max_cycle_length (int, optional): Maximum number of nodes in a
+            cycle. If :obj:`None`, defaults to :obj:`num_nodes`. Larger
+            values increase computation time. (default: :obj:`None`)
+        top_k_paths (int, optional): Number of open paths to keep after
+            each extension step. Higher values increase accuracy but use
+            more memory. (default: :obj:`100`)
+
+    Returns:
+        (torch.Tensor, float): A tuple containing:
+
+        - **cycle_nodes** (*torch.Tensor* or *None*): Node indices forming the
+          maximum weight cycle, or :obj:`None` if no cycle exists. The cycle
+          is represented as a 1D tensor where consecutive elements form edges,
+          and the last element connects back to the first.
+        - **cycle_weight** (*float*): Total weight of the cycle, or
+          :obj:`-float('inf')` if no cycle exists.
+
+    Examples:
+        >>> # Create a simple cycle: 0 -> 1 -> 2 -> 0
+        >>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
+        >>> edge_weight = torch.tensor([1.0, 2.0, 3.0])
+        >>> cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+        >>> print(cycle)
+        tensor([0, 1, 2])
+        >>> print(weight)
+        6.0
+
+        >>> # DAG with reversed edges (negative weights create cycles)
+        >>> edge_index = torch.tensor([[0, 1, 2, 1, 2], [1, 2, 3, 0, 1]])
+        >>> edge_weight = torch.tensor([5.0, 3.0, 2.0, -4.0, -1.0])
+        >>> cycle, weight = find_max_weight_cycle(edge_index, edge_weight)
+        >>> print(cycle)
+        tensor([1, 2])
+        >>> print(weight)
+        2.0
+    """
+    adj = _build_adjacency(edge_index, edge_weight, num_nodes)
+    num_nodes = len(adj)
+
+    if max_cycle_length is None:
+        max_cycle_length = num_nodes
+
+    best_weight = float('-inf')
+    best_path: Optional[List[int]] = None
+
+    for start in range(num_nodes):
+        # Open paths are `(last_node, weight, visited_bitmask, nodes)`.
+        frontier = [(start, 0.0, 1 << start, [start])]
+
+        # In step `k`, open paths hold `k` nodes, so `max_cycle_length`
+        # steps can close cycles of up to that many nodes.
+        for _ in range(max_cycle_length):
+            if not frontier:
+                break
+
+            extended = []
+            for node, weight, visited, path in frontier:
+                for next_node, edge_w in adj[node]:
+                    if next_node == start:
+                        # `len(path) > 1` rejects self-loops.
+                        total = weight + edge_w
+                        if len(path) > 1 and total > best_weight:
+                            best_weight = total
+                            best_path = path
+                        continue
+
+                    # Revisiting a node would make the cycle non-simple.
+                    if visited & (1 << next_node):
+                        continue
+
+                    extended.append((
+                        next_node,
+                        weight + edge_w,
+                        visited | (1 << next_node),
+                        path + [next_node],
+                    ))
+
+            # Keep only the heaviest open paths to bound memory.
+            if len(extended) > top_k_paths:
+                extended.sort(key=lambda item: item[1], reverse=True)
+                extended = extended[:top_k_paths]
+
+            frontier = extended
+
+    if best_path is None:
+        return None, best_weight
+
+    best_cycle = torch.tensor(
+        best_path,
+        dtype=torch.long,
+        device=edge_index.device,
+    )
+    return best_cycle, best_weight
+
+
+def find_all_cycles(
+    edge_index: Tensor,
+    edge_weight: Optional[Tensor] = None,
+    num_nodes: Optional[int] = None,
+    max_cycle_length: Optional[int] = None,
+    min_weight: float = float('-inf'),
+) -> Tuple[List[Tensor], List[float]]:
+    r"""Finds all simple cycles in a directed graph whose weight reaches a
+    threshold.
+
+    This function runs an iterative depth-first search from every node and
+    reports each cycle exactly once, starting at its smallest node index.
+    Self-loops are ignored, and parallel edges are merged by keeping the
+    heaviest one.
+
+    Args:
+        edge_index (torch.Tensor): The edge indices in COO format with shape
+            :obj:`[2, num_edges]`.
+        edge_weight (torch.Tensor, optional): Edge weights. If :obj:`None`,
+            all edges are assumed to have weight :obj:`1.0`.
+            (default: :obj:`None`)
+        num_nodes (int, optional): The number of nodes in the graph. If
+            :obj:`None`, will be inferred from :obj:`edge_index`.
+            (default: :obj:`None`)
+        max_cycle_length (int, optional): Maximum number of nodes in a
+            cycle. If :obj:`None`, defaults to :obj:`num_nodes`.
+            (default: :obj:`None`)
+        min_weight (float, optional): Minimum weight threshold for cycles.
+            Only cycles with total weight :math:`\geq` :obj:`min_weight` are
+            returned. (default: :obj:`-inf`)
+
+    Returns:
+        (list, list): A tuple containing:
+
+        - **cycles** (*list*): List of cycles, where each cycle is a
+          :obj:`torch.Tensor` of node indices.
+        - **weights** (*list*): List of corresponding cycle weights.
+
+    Examples:
+        >>> edge_index = torch.tensor([[0, 1, 2, 1], [1, 2, 0, 0]])
+        >>> edge_weight = torch.tensor([1.0, 1.0, 1.0, 2.0])
+        >>> cycles, weights = find_all_cycles(edge_index, edge_weight)
+        >>> for cycle, weight in zip(cycles, weights):
+        ...     print(f"Cycle: {cycle}, Weight: {weight}")
+        Cycle: tensor([0, 1]), Weight: 3.0
+        Cycle: tensor([0, 1, 2]), Weight: 3.0
+    """
+    adj = _build_adjacency(edge_index, edge_weight, num_nodes)
+    num_nodes = len(adj)
+
+    if max_cycle_length is None:
+        max_cycle_length = num_nodes
+
+    device = edge_index.device
+    all_cycles: List[Tensor] = []
+    all_weights: List[float] = []
+
+    for start in range(num_nodes):
+        # Stack entries are `(last_node, weight, visited_bitmask, nodes)`.
+        stack = [(start, 0.0, 1 << start, [start])]
+
+        while stack:
+            node, weight, visited, path = stack.pop()
+
+            for next_node, edge_w in adj[node]:
+                if next_node == start:
+                    # `len(path) > 1` rejects self-loops.
+                    total = weight + edge_w
+                    if len(path) > 1 and total >= min_weight:
+                        cycle = torch.tensor(
+                            path,
+                            dtype=torch.long,
+                            device=device,
+                        )
+                        all_cycles.append(cycle)
+                        all_weights.append(total)
+                    continue
+
+                # Nodes below `start` belong to cycles that were already
+                # reported from a smaller start node; visiting them again
+                # would return the same cycle once per rotation.
+                if next_node < start or visited & (1 << next_node):
+                    continue
+
+                # A longer path could only close cycles beyond the limit.
+                if len(path) >= max_cycle_length:
+                    continue
+
+                stack.append((
+                    next_node,
+                    weight + edge_w,
+                    visited | (1 << next_node),
+                    path + [next_node],
+                ))
+
+    return all_cycles, all_weights
+
+
+__all__ = [
+    'find_max_weight_cycle',
+    'find_all_cycles',
+]