@@ -13,7 +13,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
1313 const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
1414 int64_t *__restrict__ col, const scalar_t r, const int64_t n,
1515 const int64_t m, const int64_t dim, const int64_t num_examples,
16- const int64_t max_num_neighbors) {
16+ const int64_t max_num_neighbors,
17+ const bool ignore_same_index) {
1718
1819 const int64_t n_y = blockIdx .x * blockDim .x + threadIdx .x ;
1920 if (n_y >= m)
@@ -29,7 +30,7 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
2930 (x[n_x * dim + d] - y[n_y * dim + d]);
3031 }
3132
32- if (dist < r) {
33+ if (dist < r && !(ignore_same_index && n_y == n_x) ) {
3334 row[n_y * max_num_neighbors + count] = n_y;
3435 col[n_y * max_num_neighbors + count] = n_x;
3536 count++;
@@ -43,7 +44,8 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
4344torch::Tensor radius_cuda (const torch::Tensor x, const torch::Tensor y,
4445 torch::optional<torch::Tensor> ptr_x,
4546 torch::optional<torch::Tensor> ptr_y, const double r,
46- const int64_t max_num_neighbors) {
47+ const int64_t max_num_neighbors,
48+ const bool ignore_same_index) {
4749 CHECK_CUDA (x);
4850 CHECK_CONTIGUOUS (x);
4951 CHECK_INPUT (x.dim () == 2 );
@@ -86,7 +88,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
8688 ptr_x.value ().data_ptr <int64_t >(),
8789 ptr_y.value ().data_ptr <int64_t >(), row.data_ptr <int64_t >(),
8890 col.data_ptr <int64_t >(), r * r, x.size (0 ), y.size (0 ), x.size (1 ),
89- ptr_x.value ().numel () - 1 , max_num_neighbors);
91+ ptr_x.value ().numel () - 1 , max_num_neighbors, ignore_same_index );
9092 });
9193
9294 auto mask = row != -1 ;
0 commit comments