Skip to content

Commit 91f809b

Browse files
committed
rename rbf and rbfPu
1 parent 3c268e5 commit 91f809b

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

src/numericalnim/rbf.nim

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import ./utils
55

66
type
77
RbfFunc* = proc (r: Tensor[float], epsilon: float): Tensor[float]
8-
RbfType*[T] = object
8+
RbfBaseType*[T] = object
99
points*: Tensor[float] # (n_points, n_dim)
1010
values*: Tensor[T] # (n_points, n_values)
1111
coeffs*: Tensor[float] # (n_points, n_values)
@@ -19,9 +19,9 @@ type
1919
gridSize*, gridDim*: int
2020
gridDelta*: float
2121

22-
RbfPUType*[T] = object
22+
RbfType*[T] = object
2323
limits*: tuple[upper: Tensor[float], lower: Tensor[float]]
24-
grid*: RbfGrid[RbfType[T]]
24+
grid*: RbfGrid[RbfBaseType[T]]
2525
nValues*: int
2626

2727
template km(point: Tensor[float], index: int, delta: float): int =
@@ -131,14 +131,14 @@ proc compactRbfFunc*(r: Tensor[float], epsilon: float): Tensor[float] =
131131
let temp2 = temp * temp
132132
temp2*temp2 * (4*xeps + 1) * float(xeps < 1)
133133

134-
proc newRbf*[T](points: Tensor[float], values: Tensor[T], rbfFunc: RbfFunc = compactRbfFunc, epsilon: float = 1): RbfType[T] =
134+
proc newRbfBase*[T](points: Tensor[float], values: Tensor[T], rbfFunc: RbfFunc = compactRbfFunc, epsilon: float = 1): RbfBaseType[T] =
135135
assert points.shape[0] == values.shape[0]
136136
let dist = distanceMatrix(points, points)
137137
let A = rbfFunc(dist, epsilon)
138138
let coeffs = solve(A, values)
139-
result = RbfType[T](points: points, values: values, coeffs: coeffs, epsilon: epsilon, f: rbfFunc)
139+
result = RbfBaseType[T](points: points, values: values, coeffs: coeffs, epsilon: epsilon, f: rbfFunc)
140140

141-
proc eval*[T](rbf: RbfType[T], x: Tensor[float]): Tensor[T] =
141+
proc eval*[T](rbf: RbfBaseType[T], x: Tensor[float]): Tensor[T] =
142142
let dist = distanceMatrix(rbf.points, x)
143143
let A = rbf.f(dist, rbf.epsilon)
144144
result = A * rbf.coeffs
@@ -148,7 +148,7 @@ proc scalePoint*(x: Tensor[float], limits: tuple[upper: Tensor[float], lower: Te
148148
let upper = limits.upper +. 0.01
149149
(x -. lower) /. (upper - lower)
150150

151-
proc newRbfPu*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, rbfFunc: RbfFunc = compactRbfFunc, epsilon: float = 1): RbfPUType[T] =
151+
proc newRbf*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, rbfFunc: RbfFunc = compactRbfFunc, epsilon: float = 1): RbfType[T] =
152152
assert points.shape[0] == values.shape[0]
153153
assert points.shape.len == 2 and values.shape.len == 2
154154
let upperLimit = max(points, 0)
@@ -158,18 +158,18 @@ proc newRbfPu*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, r
158158
let dataGrid = newRbfGrid(scaledPoints, values, gridSize)
159159
let patchPoints = dataGrid.constructMeshedPatches()
160160
let nPatches = patchPoints.shape[0]
161-
var patchRbfs: seq[RbfType[T]] #= newTensor[RbfType[T]](nPatches, 1)
161+
var patchRbfs: seq[RbfBaseType[T]] #= newTensor[RbfBaseType[T]](nPatches, 1)
162162
var patchIndices: seq[int]
163163
for i in 0 ..< nPatches:
164164
let indices = dataGrid.findAllWithin(patchPoints[i, _], dataGrid.gridDelta)
165165
if indices.len > 0:
166-
patchRbfs.add newRbf(dataGrid.points[indices,_], values[indices, _], epsilon=epsilon)
166+
patchRbfs.add newRbfBase(dataGrid.points[indices,_], values[indices, _], epsilon=epsilon)
167167
patchIndices.add i
168168

169169
let patchGrid = newRbfGrid(patchPoints[patchIndices, _], patchRbfs.toTensor.unsqueeze(1), gridSize)
170-
result = RbfPUType[T](limits: limits, grid: patchGrid, nValues: values.shape[1])
170+
result = RbfType[T](limits: limits, grid: patchGrid, nValues: values.shape[1])
171171

172-
proc eval*[T](rbf: RbfPUType[T], x: Tensor[float]): Tensor[T] =
172+
proc eval*[T](rbf: RbfType[T], x: Tensor[float]): Tensor[T] =
173173
assert x.shape.len == 2
174174
assert (not ((x <=. rbf.limits.upper) and (x >=. rbf.limits.lower))).astype(int).sum() == 0, "Some of your points are outside the allowed limits"
175175

@@ -192,7 +192,7 @@ proc eval*[T](rbf: RbfPUType[T], x: Tensor[float]): Tensor[T] =
192192
else:
193193
result[row, _] = T(Nan) # allow to pass default value to newRbfPU?
194194

195-
proc evalAlt*[T](rbf: RbfPUType[T], x: Tensor[float]): Tensor[T] =
195+
proc evalAlt*[T](rbf: RbfType[T], x: Tensor[float]): Tensor[T] =
196196
assert x.shape.len == 2
197197
assert (not ((x <=. rbf.limits.upper) and (x >=. rbf.limits.lower))).astype(int).sum() == 0, "Some of your points are outside the allowed limits"
198198

tests/test_interpolate.nim

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,10 @@ test "Trilinear f = x*y*z T: Tensor[float]":
411411
check abs(spline.eval(i, j, k)[1] - i*j*k) < 1e-12
412412
check abs(spline.eval(i, j, k)[2] - 1) < 1e-16
413413

414-
test "rbf f=x*y*z":
414+
test "rbfBase f=x*y*z":
415415
let pos = meshgrid(arraymancer.linspace(0.0, 1.0, 5), arraymancer.linspace(0.0, 1.0, 5), arraymancer.linspace(0.0, 1.0, 5))
416416
let vals = pos[_, 0] *. pos[_, 1] *. pos[_, 2]
417-
let rbfObj = newRbf(pos, vals)
417+
let rbfObj = newRbfBase(pos, vals)
418418

419419
# We want test points in the interior to avoid the edges
420420
let xTest = meshgrid(arraymancer.linspace(0.1, 0.9, 10), arraymancer.linspace(0.1, 0.9, 10), arraymancer.linspace(0.1, 0.9, 10))
@@ -424,10 +424,10 @@ test "rbf f=x*y*z":
424424
check x < 0.16
425425
check mean_squared_error(yTest, yCorrect) < 2e-4
426426

427-
test "rbfPu f=x*y*z":
427+
test "rbf f=x*y*z":
428428
let pos = meshgrid(arraymancer.linspace(0.0, 1.0, 5), arraymancer.linspace(0.0, 1.0, 5), arraymancer.linspace(0.0, 1.0, 5))
429429
let vals = pos[_, 0] *. pos[_, 1] *. pos[_, 2]
430-
let rbfObj = newRbfPu(pos, vals)
430+
let rbfObj = newRbf(pos, vals)
431431

432432
# We want test points in the interior to avoid the edges
433433
let xTest = meshgrid(arraymancer.linspace(0.1, 0.9, 10), arraymancer.linspace(0.1, 0.9, 10), arraymancer.linspace(0.1, 0.9, 10))

0 commit comments

Comments
 (0)