Skip to content

Commit d16c692

Browse files
authored
Fix the number of neighbors bug in radius_graph (#228)
* Fix the number of neighbors bug in `radius_graph` Signed-off-by: Xuangui Huang <xuanguih@nvidia.com> * fix linting issue --------- Signed-off-by: Xuangui Huang <xuanguih@nvidia.com>
1 parent e1e788b commit d16c692

File tree

7 files changed

+81
-25
lines changed

7 files changed

+81
-25
lines changed

csrc/cpu/radius_cpu.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
88
torch::optional<torch::Tensor> ptr_x,
99
torch::optional<torch::Tensor> ptr_y, double r,
10-
int64_t max_num_neighbors, int64_t num_workers) {
10+
int64_t max_num_neighbors, int64_t num_workers,
11+
bool ignore_same_index) {
1112

1213
CHECK_CPU(x);
1314
CHECK_INPUT(x.dim() == 2);
@@ -54,10 +55,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
5455
size_t num_matches = mat_index.index->radiusSearch(
5556
y_data + i * y.size(1), r * r, ret_matches, params);
5657

57-
for (size_t j = 0; j < std::min(num_matches, (size_t)max_num_neighbors);
58-
j++) {
59-
out_vec.push_back(ret_matches[j].first);
60-
out_vec.push_back(i);
58+
for (size_t j = 0, count = 0;
59+
j < num_matches && count < (size_t)max_num_neighbors;
60+
j++) {
61+
if (!ignore_same_index || ret_matches[j].first != i) {
62+
out_vec.push_back(ret_matches[j].first);
63+
out_vec.push_back(i);
64+
count++;
65+
}
6166
}
6267
}
6368

@@ -91,10 +96,14 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
9196
size_t num_matches = mat_index.index->radiusSearch(
9297
y_data + i * y.size(1), r * r, ret_matches, params);
9398

94-
for (size_t j = 0;
95-
j < std::min(num_matches, (size_t)max_num_neighbors); j++) {
96-
out_vec.push_back(x_start + ret_matches[j].first);
97-
out_vec.push_back(i);
99+
for (size_t j = 0, count = 0;
100+
j < num_matches && count < (size_t)max_num_neighbors;
101+
j++) {
102+
if (!ignore_same_index || x_start + ret_matches[j].first != i) {
103+
out_vec.push_back(x_start + ret_matches[j].first);
104+
out_vec.push_back(i);
105+
count++;
106+
}
98107
}
99108
}
100109
}

csrc/cpu/radius_cpu.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
66
torch::optional<torch::Tensor> ptr_x,
77
torch::optional<torch::Tensor> ptr_y, double r,
8-
int64_t max_num_neighbors, int64_t num_workers);
8+
int64_t max_num_neighbors, int64_t num_workers,
9+
bool ignore_same_index);

csrc/cuda/radius_cuda.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
1313
const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
1414
int64_t *__restrict__ col, const scalar_t r, const int64_t n,
1515
const int64_t m, const int64_t dim, const int64_t num_examples,
16-
const int64_t max_num_neighbors) {
16+
const int64_t max_num_neighbors,
17+
const bool ignore_same_index) {
1718

1819
const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
1920
if (n_y >= m)
@@ -29,7 +30,7 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
2930
(x[n_x * dim + d] - y[n_y * dim + d]);
3031
}
3132

32-
if (dist < r) {
33+
if (dist < r && !(ignore_same_index && n_y == n_x)) {
3334
row[n_y * max_num_neighbors + count] = n_y;
3435
col[n_y * max_num_neighbors + count] = n_x;
3536
count++;
@@ -43,7 +44,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
4344
torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
4445
torch::optional<torch::Tensor> ptr_x,
4546
torch::optional<torch::Tensor> ptr_y, const double r,
46-
const int64_t max_num_neighbors) {
47+
const int64_t max_num_neighbors,
48+
const bool ignore_same_index) {
4749
CHECK_CUDA(x);
4850
CHECK_CONTIGUOUS(x);
4951
CHECK_INPUT(x.dim() == 2);
@@ -86,7 +88,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
8688
ptr_x.value().data_ptr<int64_t>(),
8789
ptr_y.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
8890
col.data_ptr<int64_t>(), r * r, x.size(0), y.size(0), x.size(1),
89-
ptr_x.value().numel() - 1, max_num_neighbors);
91+
ptr_x.value().numel() - 1, max_num_neighbors, ignore_same_index);
9092
});
9193

9294
auto mask = row != -1;

csrc/cuda/radius_cuda.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
66
torch::optional<torch::Tensor> ptr_x,
77
torch::optional<torch::Tensor> ptr_y, double r,
8-
int64_t max_num_neighbors);
8+
int64_t max_num_neighbors,
9+
bool ignore_same_index);

csrc/radius.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ PyMODINIT_FUNC PyInit__radius_cpu(void) { return NULL; }
2222
CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y,
2323
torch::optional<torch::Tensor> ptr_x,
2424
torch::optional<torch::Tensor> ptr_y, double r,
25-
int64_t max_num_neighbors, int64_t num_workers) {
25+
int64_t max_num_neighbors, int64_t num_workers,
26+
bool ignore_same_index) {
2627
if (x.device().is_cuda()) {
2728
#ifdef WITH_CUDA
28-
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors);
29+
return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors, ignore_same_index);
2930
#else
3031
AT_ERROR("Not compiled with CUDA support");
3132
#endif
3233
} else {
33-
return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers);
34+
return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers, ignore_same_index);
3435
}
3536
}
3637

test/test_radius.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ def to_set(edge_index):
1111
return set([(i, j) for i, j in edge_index.t().tolist()])
1212

1313

14+
def to_degree(edge_index):
15+
_, counts = torch.unique(edge_index[1], return_counts=True)
16+
return counts.tolist()
17+
18+
19+
def to_batch(nodes):
20+
return [int(i / 4) for i in nodes]
21+
22+
1423
@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
1524
def test_radius(dtype, device):
1625
x = tensor([
@@ -74,6 +83,38 @@ def test_radius_graph(dtype, device):
7483
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
7584
(3, 2), (0, 3), (2, 3)])
7685

86+
edge_index = radius_graph(x, r=100, flow='source_to_target',
87+
max_num_neighbors=1)
88+
assert set(to_degree(edge_index)) == set([1])
89+
90+
x = tensor([
91+
[-1, -1],
92+
[-1, -1],
93+
[-1, -1],
94+
[-1, -1],
95+
], dtype, device)
96+
97+
edge_index = radius_graph(x, r=100, flow='source_to_target',
98+
max_num_neighbors=1)
99+
assert set(to_degree(edge_index)) == set([1])
100+
101+
x = tensor([
102+
[-1, -1],
103+
[-1, +1],
104+
[+1, +1],
105+
[+1, -1],
106+
[-1, -1],
107+
[-1, +1],
108+
[+1, +1],
109+
[+1, -1],
110+
], dtype, device)
111+
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
112+
113+
edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target',
114+
max_num_neighbors=1)
115+
assert set(to_degree(edge_index)) == set([1])
116+
assert to_batch(edge_index[0]) == batch_x.tolist()
117+
77118

78119
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
79120
def test_radius_graph_large(dtype, device):

torch_cluster/radius.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def radius(
1212
max_num_neighbors: int = 32,
1313
num_workers: int = 1,
1414
batch_size: Optional[int] = None,
15+
ignore_same_index: bool = False
1516
) -> torch.Tensor:
1617
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
1718
distance :obj:`r`.
@@ -40,6 +41,9 @@ def radius(
4041
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
4142
batch_size (int, optional): The number of examples :math:`B`.
4243
Automatically calculated if not given. (default: :obj:`None`)
44+
ignore_same_index (bool, optional): If :obj:`True`, each element in
45+
:obj:`y` ignores the point in :obj:`x` with the same index.
46+
(default: :obj:`False`)
4347
4448
.. code-block:: python
4549
@@ -80,7 +84,8 @@ def radius(
8084
ptr_y = torch.bucketize(arange, batch_y)
8185

8286
return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
83-
max_num_neighbors, num_workers)
87+
max_num_neighbors, num_workers,
88+
ignore_same_index)
8489

8590

8691
def radius_graph(
@@ -133,15 +138,11 @@ def radius_graph(
133138

134139
assert flow in ['source_to_target', 'target_to_source']
135140
edge_index = radius(x, x, r, batch, batch,
136-
max_num_neighbors if loop else max_num_neighbors + 1,
137-
num_workers, batch_size)
141+
max_num_neighbors,
142+
num_workers, batch_size, not loop)
138143
if flow == 'source_to_target':
139144
row, col = edge_index[1], edge_index[0]
140145
else:
141146
row, col = edge_index[0], edge_index[1]
142147

143-
if not loop:
144-
mask = row != col
145-
row, col = row[mask], col[mask]
146-
147148
return torch.stack([row, col], dim=0)

0 commit comments

Comments
 (0)