Skip to content

Commit fbcdea7

Browse files
committed
implement simple rbf
1 parent aaa7bb6 commit fbcdea7

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

src/numericalnim/rbf.nim

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import std / [math, algorithm, tables]
2+
import arraymancer
3+
4+
5+
type
6+
RbfType*[T] = object
7+
points*: Tensor[float] # (n_points, n_dim)
8+
values*: Tensor[T] # (n_points, n_values)
9+
coeffs*: Tensor[float] # (n_points, n_values)
10+
epsilon*: float
11+
12+
proc distanceMatrix(p1, p2: Tensor[float]): Tensor[float] =
13+
## Returns distance matrix of shape (n_points, n_points)
14+
let n_points = p1.shape[0]
15+
let n_dims = p1.shape[1]
16+
result = newTensor[float](n_points, n_points)
17+
for i in 0 ..< n_points:
18+
for j in 0 ..< n_points:
19+
var r2 = 0.0
20+
for k in 0 ..< n_dims:
21+
let diff = p1[i,k] - p2[j,k]
22+
r2 += diff * diff
23+
result[i, j] = sqrt(r2)
24+
25+
proc compactRbfFunc*(r: Tensor[float], epsilon: float): Tensor[float] =
26+
result = map_inline(r):
27+
(1 - x/epsilon) ^ 4 * (4*x/epsilon + 1) * float(x < epsilon)
28+
29+
proc newRbf*[T](points: Tensor[float], values: Tensor[T], rbfFunc: proc (r: Tensor[float], epsilon: float): Tensor[float] = compactRbfFunc, epsilon: float = 1): RbfType[T] =
30+
assert points.shape[0] == values.shape[0]
31+
let dist = distanceMatrix(points, points)
32+
let A = rbfFunc(dist, epsilon)
33+
let coeffs = solve(A, values)
34+
result = RbfType[T](points: points, values: values, coeffs: coeffs, epsilon: epsilon)
35+
36+
let x1 = @[@[0.0, 0.0, 0.0], @[1.0, 1.0, 0.0], @[1.0, 2.0, 0.0]].toTensor
37+
let x2 = @[@[0.0, 0.0, 1.0], @[1.0, 1.0, 2.0], @[1.0, 2.0, 3.0]].toTensor
38+
echo distanceMatrix(x1, x1)
39+
let values = @[0.0, 1.0, 2.0].toTensor
40+
echo newRbf(x1, values, epsilon=10)
41+
42+
import benchy
43+
44+
let pos = randomTensor(5000, 3, 1.0)
45+
let vals = randomTensor(5000, 1, 1.0)
46+
timeIt "Rbf":
47+
keep newRbf(pos, vals)
48+

0 commit comments

Comments
 (0)