Skip to content

Commit b01f40c

Browse files
usedbytesDennis Kuhnert
andauthored
Fix axis handling in Remove (#7)
* test: Add tests for Remove "axis inversion" In certain circumstances, Remove can lead to an invalid tree. These tests catch that case. * Take axis into account when choosing replacements in Remove This fixes the "axis inversion" issue shown in TestKDTree_RemoveAxisInversion Fixes #6 * Add assertions in tests Co-authored-by: Dennis Kuhnert <dennis.kuhnert@sap.com>
1 parent f5d74e8 commit b01f40c

File tree

2 files changed

+84
-3
lines changed

2 files changed

+84
-3
lines changed

kdtree.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ func (n *node) Remove(p Point, axis int) (*node, *node) {
316316

317317
if n.Left != nil {
318318
largest := n.Left.FindLargest(axis, nil)
319-
removed, sub := n.Left.Remove(largest, axis)
319+
removed, sub := n.Left.Remove(largest, (axis+1)%n.Dimensions())
320320

321321
removed.Left = n.Left
322322
removed.Right = n.Right
@@ -328,7 +328,7 @@ func (n *node) Remove(p Point, axis int) (*node, *node) {
328328

329329
if n.Right != nil {
330330
smallest := n.Right.FindSmallest(axis, nil)
331-
removed, sub := n.Right.Remove(smallest, axis)
331+
removed, sub := n.Right.Remove(smallest, (axis+1)%n.Dimensions())
332332

333333
removed.Left = n.Left
334334
removed.Right = n.Right

kdtree_test.go

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,78 @@ func TestKDTree_RangeSearchWithGenerator(t *testing.T) {
497497
}
498498
}
499499

500+
// TestKDTree_RemoveAxisInversion is a targeted test for issue #6.
501+
//
502+
// https://github.com/kyroy/kdtree/issues/6
503+
//
504+
// Remove wasn't correctly taking into account the axis when searching for
505+
// replacements/substitutes. This caused an incorrect result when removing the
506+
// root node from this tree.
507+
//
508+
// This is because the {171, 176} node starts on the 'left' branch of the
509+
// {238, 155} node, which is correct if indexed by the X axis. When the root
510+
// node is removed, {238, 155} instead becomes indexed on the Y axis, but
511+
// {171, 176} was being left on the 'left' branch.
512+
//
513+
// This test verifies the fix and should help prevent regressions
514+
func TestKDTree_RemoveAxisInversion(t *testing.T) {
515+
tree := kdtree.New([]kdtree.Point{
516+
&Point2D{X: 171, Y: 176},
517+
&Point2D{X: 238, Y: 155},
518+
&Point2D{X: 257, Y: 246},
519+
&Point2D{X: 181, Y: 265},
520+
&Point2D{X: 206, Y: 282},
521+
&Point2D{X: 265, Y: 176},
522+
&Point2D{X: 284, Y: 209},
523+
&Point2D{X: 296, Y: 168},
524+
&Point2D{X: 280, Y: 225},
525+
&Point2D{X: 288, Y: 283},
526+
&Point2D{X: 289, Y: 292},
527+
})
528+
search := &Point2D{X: 150, Y: 218}
529+
remove := &Point2D{X: 265, Y: 176}
530+
531+
tree.Remove(remove)
532+
533+
fewNN := tree.KNN(search, 1)
534+
manyNN := tree.KNN(search, 10)
535+
536+
assertPointsEqual(t, fewNN[0], manyNN[0])
537+
}
538+
539+
func TestKDTree_RemoveAxisInversionGenerator(t *testing.T) {
540+
for dims := 2; dims <= 4; dims++ {
541+
maxSize := int(math.Pow(float64(dims), 4))
542+
543+
tree := kdtree.New(nil)
544+
arr := make([]kdtree.Point, 0, maxSize+1)
545+
for i := 0; i < 1000; i++ {
546+
p := generateTestPoint(dims)
547+
548+
// Two KNN queries
549+
fewNN := tree.KNN(p, 1)
550+
manyNN := tree.KNN(p, maxSize)
551+
552+
if len(arr) > 0 {
553+
assertPointsEqual(t, fewNN[0], manyNN[0])
554+
}
555+
556+
// Add in the new point
557+
arr = append(arr, p)
558+
tree.Insert(p)
559+
560+
// Limit the max number of elements - which will also
561+
// introduce some churn in the tree
562+
if len(arr) > maxSize {
563+
idx := rand.Intn(len(arr))
564+
tree.Remove(arr[idx])
565+
arr[idx] = arr[len(arr)-1]
566+
arr = arr[:len(arr)-1]
567+
}
568+
}
569+
}
570+
}
571+
500572
// benchmarks
501573

502574
var resultTree *kdtree.KDTree
@@ -559,6 +631,15 @@ func generateTestCaseData(size int) []kdtree.Point {
559631
return points
560632
}
561633

634+
func generateTestPoint(dimensions int) kdtree.Point {
635+
r := rand.New(rand.NewSource(time.Now().UnixNano()))
636+
values := make([]float64, dimensions)
637+
for j := range values {
638+
values[j] = r.Float64()*3000 - 1500
639+
}
640+
return NewPoint(values, nil)
641+
}
642+
562643
func prioQueueKNN(points []kdtree.Point, p kdtree.Point, k int) []kdtree.Point {
563644
knn := make([]kdtree.Point, 0, k)
564645
if p == nil {
@@ -607,6 +688,6 @@ func distance(p1, p2 kdtree.Point) float64 {
607688
func assertPointsEqual(t *testing.T, p1 kdtree.Point, p2 kdtree.Point) {
608689
assert.Equal(t, p1.Dimensions(), p2.Dimensions())
609690
for i := 0; i < p1.Dimensions(); i++ {
610-
assert.Equal(t, p1.Dimension(i), p2.Dimension(i))
691+
assert.Equal(t, p1.Dimension(i), p2.Dimension(i), "assert equal dimension %d", i)
611692
}
612693
}

0 commit comments

Comments
 (0)