Skip to content

Commit 04d3227

Browse files
Merge pull request #34 from SciNim/rbf
fix rbf bug
2 parents a17b80c + 75a8774 commit 04d3227

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

src/numericalnim/rbf.nim

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ proc findIndex*[T](grid: RbfGrid[T], point: Tensor[float]): int =
5656
result += (km(point, i, grid.gridDelta) - 1) * grid.gridSize ^ (grid.gridDim - i - 1)
5757

5858
proc constructMeshedPatches*[T](grid: RbfGrid[T]): Tensor[float] =
59-
meshgrid(@[arraymancer.linspace(0 + grid.gridDelta / 2, 1 - grid.gridDelta / 2, grid.gridSize)].cycle(grid.gridDim))
59+
if grid.gridSize == 1:
60+
@[@[0.5].cycle(grid.gridDim)].toTensor
61+
else:
62+
meshgrid(@[arraymancer.linspace(0 + grid.gridDelta / 2, 1 - grid.gridDelta / 2, grid.gridSize)].cycle(grid.gridDim))
6063

6164
template dist2(p1, p2: Tensor[float]): float =
6265
var result = 0.0
@@ -85,14 +88,16 @@ proc findAllBetween*[T](grid: RbfGrid[T], x: Tensor[float], rho1, rho2: float):
8588
if rho1*rho1 <= d and d <= rho2*rho2:
8689
result.add i
8790

91+
proc calcGridSize(nPoints, nDims: int, gridSize: int): int =
92+
if gridSize > 0:
93+
gridSize
94+
else:
95+
max(int(round(pow(nPoints.float, 1 / nDims) / 2)), 1)
96+
8897
proc newRbfGrid*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0): RbfGrid[T] =
8998
let nPoints = points.shape[0]
9099
let nDims = points.shape[1]
91-
let gridSize =
92-
if gridSize > 0:
93-
gridSize
94-
else:
95-
max(int(round(pow(nPoints.float, 1 / nDims) / 2)), 1)
100+
let gridSize = calcGridSize(nPoints, nDims, gridSize)
96101
let delta = 1 / gridSize
97102
result = RbfGrid[T](gridSize: gridSize, gridDim: nDims, gridDelta: delta, indices: newSeq[seq[int]](gridSize ^ nDims))
98103
for row in 0 ..< nPoints:
@@ -152,13 +157,20 @@ proc newRbf*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, rbf
152157
## epsilon: shape parameter. Default 1.
153158
assert points.shape[0] == values.shape[0]
154159
assert points.shape.len == 2 and values.shape.len == 2
160+
155161
let upperLimit = max(points, 0)
156162
let lowerLimit = min(points, 0)
157163
let limits = (upper: upperLimit, lower: lowerLimit)
158164
let scaledPoints = points.scalePoint(limits)
165+
166+
let nPoints = points.shape[0]
167+
let nDims = points.shape[1]
168+
let gridSize = calcGridSize(nPoints, nDims, gridSize)
169+
159170
let dataGrid = newRbfGrid(scaledPoints, values, gridSize)
160171
let patchPoints = dataGrid.constructMeshedPatches()
161172
let nPatches = patchPoints.shape[0]
173+
162174
var patchRbfs: seq[RbfBaseType[T]] #= newTensor[RbfBaseType[T]](nPatches, 1)
163175
var patchIndices: seq[int]
164176
for i in 0 ..< nPatches:

tests/test_interpolate.nim

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,6 @@ test "rbf f=x*y*z":
434434
let yTest = rbfObj.eval(xTest)
435435
let yCorrect = xTest[_, 0] *. xTest[_, 1] *. xTest[_, 2]
436436
for x in abs(yCorrect - yTest):
437-
check x < 0.03
438-
check mean_squared_error(yTest, yCorrect) < 1e-4
437+
check x < 0.11
438+
check mean_squared_error(yTest, yCorrect) < 1.4e-4
439439

0 commit comments

Comments
 (0)