Skip to content

Commit 2613b49

Browse files
committed
fix neighbours iterator and add test
1 parent c22abde commit 2613b49

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

src/numericalnim/rbf.nim

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import std / [math, algorithm, tables, sequtils]
1+
import std / [math, algorithm, tables, sequtils, strutils]
22
import arraymancer
33
import ./utils
44

@@ -34,18 +34,13 @@ iterator neighbours*[T](grid: RbfGrid[T], k: int): int =
3434
var kNeigh = k
3535
for i, x in dir:
3636
let step = grid.gridSize ^ (grid.gridDim - i - 1)
37-
if i == dir.high and k mod grid.gridSize == 0 and x == -1:
37+
if (k div step) mod grid.gridSize == 0 and x == -1:
3838
break loopBody
39-
elif i == dir.high and k mod grid.gridSize == grid.gridSize - 1 and x == 1:
40-
break loopBody
41-
elif k div step == 0 and x == -1:
42-
break loopBody
43-
elif k div step == grid.gridSize - 1 and x == 1:
39+
elif (k div step) mod grid.gridSize == grid.gridSize - 1 and x == 1:
4440
break loopBody
4541
else:
4642
kNeigh += x * step
4743
if kNeigh >= 0 and kNeigh < grid.gridSize ^ grid.gridDim:
48-
echo kNeigh, ": ", dir
4944
yield kNeigh
5045

5146

@@ -147,7 +142,7 @@ proc newRbfPu*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, r
147142
for i in 0 ..< nPatches:
148143
let indices = dataGrid.findAllWithin(patchPoints[i, _], dataGrid.gridDelta)
149144
if indices.len > 0:
150-
patchRbfs.add newRbf(dataGrid.points[indices,_], values[indices, _], epsilon=dataGrid.gridDelta)
145+
patchRbfs.add newRbf(dataGrid.points[indices,_], values[indices, _], epsilon=epsilon)
151146
patchIndices.add i
152147

153148
let patchGrid = newRbfGrid(patchPoints[patchIndices, _], patchRbfs.toTensor.unsqueeze(1), gridSize)
@@ -194,8 +189,8 @@ when isMainModule:
194189

195190
#echo rbfPu.eval(sqrt x1)
196191
echo "----------------"
197-
let xGrid = [[0.1, 0.1], [0.2, 0.3], [0.9, 0.9], [0.4, 0.4]].toTensor
192+
#[ let xGrid = [[0.1, 0.1], [0.2, 0.3], [0.9, 0.9], [0.4, 0.4]].toTensor
198193
let valuesGrid = @[0, 1, 9, 5].toTensor.reshape(4, 1)
199-
let grid = newRbfGrid(xGrid, valuesGrid, 3)
194+
let grid = newRbfGrid(xGrid, valuesGrid, 2)
200195
echo grid
201-
echo grid.neighbours(8).toSeq
196+
echo grid.neighbours(3).toSeq ]#

tests/test_interpolate.nim

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,17 @@ test "rbf f=x*y*z":
423423
for x in abs(yCorrect - yTest):
424424
check x < 0.16
425425
check mean_squared_error(yTest, yCorrect) < 2e-4
426+
427+
test "rbfPu f=x*y*z":
428+
let pos = meshgrid(arraymancer.linspace(0.0, 1.0, 5), arraymancer.linspace(0.0, 1.0, 5), arraymancer.linspace(0.0, 1.0, 5))
429+
let vals = pos[_, 0] *. pos[_, 1] *. pos[_, 2]
430+
let rbfObj = newRbfPu(pos, vals)
431+
432+
# We want test points in the interior to avoid the edges
433+
let xTest = meshgrid(arraymancer.linspace(0.1, 0.9, 10), arraymancer.linspace(0.1, 0.9, 10), arraymancer.linspace(0.1, 0.9, 10))
434+
let yTest = rbfObj.eval(xTest)
435+
let yCorrect = xTest[_, 0] *. xTest[_, 1] *. xTest[_, 2]
436+
for x in abs(yCorrect - yTest):
437+
check x < 0.03
438+
check mean_squared_error(yTest, yCorrect) < 1e-4
439+

0 commit comments

Comments
 (0)