diff --git a/test/utils/test_num_nodes.py b/test/utils/test_num_nodes.py index ce18d6113315..f02341b52237 100644 --- a/test/utils/test_num_nodes.py +++ b/test/utils/test_num_nodes.py @@ -30,3 +30,31 @@ def test_maybe_num_nodes_dict(): '1': 3, '2': 6, } + + +def test_maybe_num_nodes_dict_empty_edge_index(): + # Test with empty edge indices (regression test for bug fix) + edge_index_dict = { + ('user', 'rates', 'movie'): + torch.tensor([[], []], dtype=torch.long), + ('user', 'follows', 'user'): + torch.tensor([[0, 1], [1, 2]], dtype=torch.long), + } + + result = maybe_num_nodes_dict(edge_index_dict) + assert result == {'user': 3, 'movie': 0} + + # Test with all empty edge indices + edge_index_dict_all_empty = { + ('user', 'rates', 'movie'): torch.tensor([[], []], dtype=torch.long), + ('movie', 'in', 'genre'): torch.tensor([[], []], dtype=torch.long), + } + + result_all_empty = maybe_num_nodes_dict(edge_index_dict_all_empty) + assert result_all_empty == {'user': 0, 'movie': 0, 'genre': 0} + + # Test with provided num_nodes_dict and empty edges + num_nodes_dict = {'movie': 10} + result_with_provided = maybe_num_nodes_dict(edge_index_dict, + num_nodes_dict) + assert result_with_provided == {'user': 3, 'movie': 10} diff --git a/torch_geometric/utils/num_nodes.py b/torch_geometric/utils/num_nodes.py index 1f5e3b6b3392..4be8759d3744 100644 --- a/torch_geometric/utils/num_nodes.py +++ b/torch_geometric/utils/num_nodes.py @@ -52,12 +52,14 @@ def maybe_num_nodes_dict( key = keys[0] if key not in found_types: - N = int(edge_index[0].max() + 1) + N = int(edge_index[0].max() + + 1) if edge_index[0].numel() > 0 else 0 num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) key = keys[-1] if key not in found_types: - N = int(edge_index[1].max() + 1) + N = int(edge_index[1].max() + + 1) if edge_index[1].numel() > 0 else 0 num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) return num_nodes_dict