|
| 1 | +import std / [math, algorithm, tables] |
| 2 | +import arraymancer |
| 3 | + |
| 4 | + |
| 5 | +type |
| 6 | + RbfType*[T] = object |
| 7 | + points*: Tensor[float] # (n_points, n_dim) |
| 8 | + values*: Tensor[T] # (n_points, n_values) |
| 9 | + coeffs*: Tensor[float] # (n_points, n_values) |
| 10 | + epsilon*: float |
| 11 | + |
| 12 | +proc distanceMatrix(p1, p2: Tensor[float]): Tensor[float] = |
| 13 | + ## Returns distance matrix of shape (n_points, n_points) |
| 14 | + let n_points = p1.shape[0] |
| 15 | + let n_dims = p1.shape[1] |
| 16 | + result = newTensor[float](n_points, n_points) |
| 17 | + for i in 0 ..< n_points: |
| 18 | + for j in 0 ..< n_points: |
| 19 | + var r2 = 0.0 |
| 20 | + for k in 0 ..< n_dims: |
| 21 | + let diff = p1[i,k] - p2[j,k] |
| 22 | + r2 += diff * diff |
| 23 | + result[i, j] = sqrt(r2) |
| 24 | + |
| 25 | +proc compactRbfFunc*(r: Tensor[float], epsilon: float): Tensor[float] = |
| 26 | + result = map_inline(r): |
| 27 | + (1 - x/epsilon) ^ 4 * (4*x/epsilon + 1) * float(x < epsilon) |
| 28 | + |
| 29 | +proc newRbf*[T](points: Tensor[float], values: Tensor[T], rbfFunc: proc (r: Tensor[float], epsilon: float): Tensor[float] = compactRbfFunc, epsilon: float = 1): RbfType[T] = |
| 30 | + assert points.shape[0] == values.shape[0] |
| 31 | + let dist = distanceMatrix(points, points) |
| 32 | + let A = rbfFunc(dist, epsilon) |
| 33 | + let coeffs = solve(A, values) |
| 34 | + result = RbfType[T](points: points, values: values, coeffs: coeffs, epsilon: epsilon) |
| 35 | + |
| 36 | +let x1 = @[@[0.0, 0.0, 0.0], @[1.0, 1.0, 0.0], @[1.0, 2.0, 0.0]].toTensor |
| 37 | +let x2 = @[@[0.0, 0.0, 1.0], @[1.0, 1.0, 2.0], @[1.0, 2.0, 3.0]].toTensor |
| 38 | +echo distanceMatrix(x1, x1) |
| 39 | +let values = @[0.0, 1.0, 2.0].toTensor |
| 40 | +echo newRbf(x1, values, epsilon=10) |
| 41 | + |
| 42 | +import benchy |
| 43 | + |
| 44 | +let pos = randomTensor(5000, 3, 1.0) |
| 45 | +let vals = randomTensor(5000, 1, 1.0) |
| 46 | +timeIt "Rbf": |
| 47 | + keep newRbf(pos, vals) |
| 48 | + |
0 commit comments