Skip to content

Commit 518fb76

Browse files
committed
refactor RbfGrid to indice based
1 parent b5a6b49 commit 518fb76

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

src/numericalnim/rbf.nim

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ type
1212
f*: RbfFunc
1313

1414
RbfGrid*[T] = object
15-
grid*: seq[seq[T]]
15+
indices*: seq[seq[int]]
16+
values*: Tensor[T]
17+
points*: Tensor[float]
1618
gridSize*, gridDim*: int
1719
gridDelta*: float
1820

@@ -25,24 +27,29 @@ type
2527
template km(point: Tensor[float], index: int, delta: float): int =
2628
int(ceil(point[0, index] / delta))
2729

30+
iterator neighbours*[T](grid: RbfGrid[T], k: int): int =
31+
discard
32+
2833
proc findIndex*[T](grid: RbfGrid[T], point: Tensor[float]): int =
2934
result = km(point, grid.gridDim - 1, grid.gridDelta) - 1
3035
for i in 0 ..< grid.gridDim - 1:
3136
result += (km(point, i, grid.gridDelta) - 1) * grid.gridSize ^ (grid.gridDim - i - 1)
3237

33-
proc newRbfGrid*[T](points: Tensor[float], values: seq[T], gridSize: int = 0): RbfGrid[T] =
38+
proc newRbfGrid*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0): RbfGrid[T] =
3439
let nPoints = points.shape[0]
3540
let nDims = points.shape[1]
3641
let gridSize =
3742
if gridSize > 0:
3843
gridSize
3944
else:
40-
int(round(pow(nPoints.float, 1 / nDims) / 2))
45+
max(int(round(pow(nPoints.float, 1 / nDims) / 2)), 1)
4146
let delta = 1 / gridSize
42-
result = RbfGrid[T](gridSize: gridSize, gridDim: nDims, gridDelta: delta, grid: newSeq[seq[T]](gridSize ^ nDims))
47+
result = RbfGrid[T](gridSize: gridSize, gridDim: nDims, gridDelta: delta, indices: newSeq[seq[int]](gridSize ^ nDims))
4348
for row in 0 ..< nPoints:
4449
let index = result.findIndex(points[row, _])
45-
result.grid[index].add values[row]
50+
result.indices[index].add row
51+
result.values = values
52+
result.points = points
4653

4754

4855
# Idea: blocked distance matrix for better cache friendliness
@@ -106,6 +113,6 @@ when isMainModule:
106113
let rbfPu = newRbfPu(x1, values, 3)
107114

108115
echo "----------------"
109-
let xGrid = [[0.1, 0.1], [0.9, 0.9], [0.4, 0.4]].toTensor
110-
let valuesGrid = @[0, 9, 5]
116+
let xGrid = [[0.1, 0.1], [0.2, 0.3], [0.9, 0.9], [0.4, 0.4]].toTensor
117+
let valuesGrid = @[0, 1, 9, 5].toTensor.reshape(4, 1)
111118
echo newRbfGrid(xGrid, valuesGrid, 3)

0 commit comments

Comments
 (0)