Skip to content

Commit c22abde

Browse files
committed
implement neighbours iterator
1 parent e9bab2e commit c22abde

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

src/numericalnim/rbf.nim

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,31 @@ template km(point: Tensor[float], index: int, delta: float): int =
2828
int(ceil(point[0, index] / delta))
2929

3030
iterator neighbours*[T](grid: RbfGrid[T], k: int): int =
31-
discard
32-
33-
iterator neighboursIncludingCenter*[T](grid: RbfGrid[T], k: int): int =
34-
yield k
31+
# TODO: Create product iterator that doesn't need to allocate 3^gridDim seqs
32+
for dir in product(@[@[-1, 0, 1]].cycle(grid.gridDim)):
33+
block loopBody:
34+
var kNeigh = k
35+
for i, x in dir:
36+
let step = grid.gridSize ^ (grid.gridDim - i - 1)
37+
if i == dir.high and k mod grid.gridSize == 0 and x == -1:
38+
break loopBody
39+
elif i == dir.high and k mod grid.gridSize == grid.gridSize - 1 and x == 1:
40+
break loopBody
41+
elif k div step == 0 and x == -1:
42+
break loopBody
43+
elif k div step == grid.gridSize - 1 and x == 1:
44+
break loopBody
45+
else:
46+
kNeigh += x * step
47+
if kNeigh >= 0 and kNeigh < grid.gridSize ^ grid.gridDim:
48+
echo kNeigh, ": ", dir
49+
yield kNeigh
50+
51+
52+
iterator neighboursExcludingCenter*[T](grid: RbfGrid[T], k: int): int =
3553
for x in grid.neighbours(k):
36-
yield x
54+
if x != k:
55+
yield x
3756

3857
proc findIndex*[T](grid: RbfGrid[T], point: Tensor[float]): int =
3958
result = km(point, grid.gridDim - 1, grid.gridDelta) - 1
@@ -53,7 +72,7 @@ template dist2(p1, p2: Tensor[float]): float =
5372
proc findAllWithin*[T](grid: RbfGrid[T], x: Tensor[float], rho: float): seq[int] =
5473
assert x.shape.len == 2 and x.shape[0] == 1
5574
let index = grid.findIndex(x)
56-
for k in grid.neighboursIncludingCenter(index):
75+
for k in grid.neighbours(index):
5776
for i in grid.indices[k]:
5877
if dist2(x, grid.points[i, _]) <= rho*rho:
5978
result.add i
@@ -105,10 +124,7 @@ proc newRbf*[T](points: Tensor[float], values: Tensor[T], rbfFunc: RbfFunc = com
105124

106125
proc eval*[T](rbf: RbfType[T], x: Tensor[float]): Tensor[T] =
107126
let dist = distanceMatrix(rbf.points, x)
108-
echo dist
109127
let A = rbf.f(dist, rbf.epsilon)
110-
echo A
111-
echo "---------------"
112128
result = A * rbf.coeffs
113129

114130
proc scalePoint*(x: Tensor[float], limits: tuple[upper: Tensor[float], lower: Tensor[float]]): Tensor[float] =
@@ -172,12 +188,14 @@ when isMainModule:
172188
#let vals = randomTensor(5000, 1, 1.0)
173189
# timeIt "Rbf":
174190
# keep newRbf(pos, vals)
175-
let rbfPu = newRbfPu(x1, values, 3)
176-
echo rbfPu.grid.values[1, 0]
191+
#let rbfPu = newRbfPu(x1, values, 3)
192+
#echo rbfPu.grid.values[1, 0]
177193
#echo rbfPu.eval(x1[[2, 1, 0], _])
178194

179-
echo rbfPu.eval(sqrt x1)
195+
#echo rbfPu.eval(sqrt x1)
180196
echo "----------------"
181-
#let xGrid = [[0.1, 0.1], [0.2, 0.3], [0.9, 0.9], [0.4, 0.4]].toTensor
182-
#let valuesGrid = @[0, 1, 9, 5].toTensor.reshape(4, 1)
183-
#echo newRbfGrid(xGrid, valuesGrid, 3)
197+
let xGrid = [[0.1, 0.1], [0.2, 0.3], [0.9, 0.9], [0.4, 0.4]].toTensor
198+
let valuesGrid = @[0, 1, 9, 5].toTensor.reshape(4, 1)
199+
let grid = newRbfGrid(xGrid, valuesGrid, 3)
200+
echo grid
201+
echo grid.neighbours(8).toSeq

0 commit comments

Comments
 (0)