@@ -56,7 +56,10 @@ proc findIndex*[T](grid: RbfGrid[T], point: Tensor[float]): int =
5656 result += (km (point, i, grid.gridDelta) - 1 ) * grid.gridSize ^ (grid.gridDim - i - 1 )
5757
5858proc constructMeshedPatches * [T](grid: RbfGrid [T]): Tensor [float ] =
59- meshgrid (@ [arraymancer.linspace (0 + grid.gridDelta / 2 , 1 - grid.gridDelta / 2 , grid.gridSize)].cycle (grid.gridDim))
59+ if grid.gridSize == 1 :
60+ @ [@ [0.5 ].cycle (grid.gridDim)].toTensor
61+ else :
62+ meshgrid (@ [arraymancer.linspace (0 + grid.gridDelta / 2 , 1 - grid.gridDelta / 2 , grid.gridSize)].cycle (grid.gridDim))
6063
6164template dist2 (p1, p2: Tensor [float ]): float =
6265 var result = 0.0
@@ -85,14 +88,16 @@ proc findAllBetween*[T](grid: RbfGrid[T], x: Tensor[float], rho1, rho2: float):
8588 if rho1* rho1 <= d and d <= rho2* rho2:
8689 result .add i
8790
91+ proc calcGridSize (nPoints, nDims: int , gridSize: int ): int =
92+ if gridSize > 0 :
93+ gridSize
94+ else :
95+ max (int (round (pow (nPoints.float , 1 / nDims) / 2 )), 1 )
96+
8897proc newRbfGrid * [T](points: Tensor [float ], values: Tensor [T], gridSize: int = 0 ): RbfGrid [T] =
8998 let nPoints = points.shape[0 ]
9099 let nDims = points.shape[1 ]
91- let gridSize =
92- if gridSize > 0 :
93- gridSize
94- else :
95- max (int (round (pow (nPoints.float , 1 / nDims) / 2 )), 1 )
100+ let gridSize = calcGridSize (nPoints, nDims, gridSize)
96101 let delta = 1 / gridSize
97102 result = RbfGrid [T](gridSize: gridSize, gridDim: nDims, gridDelta: delta, indices: newSeq [seq [int ]](gridSize ^ nDims))
98103 for row in 0 ..< nPoints:
@@ -152,13 +157,20 @@ proc newRbf*[T](points: Tensor[float], values: Tensor[T], gridSize: int = 0, rbf
152157 # # epsilon: shape parameter. Default 1.
153158 assert points.shape[0 ] == values.shape[0 ]
154159 assert points.shape.len == 2 and values.shape.len == 2
160+
155161 let upperLimit = max (points, 0 )
156162 let lowerLimit = min (points, 0 )
157163 let limits = (upper: upperLimit, lower: lowerLimit)
158164 let scaledPoints = points.scalePoint (limits)
165+
166+ let nPoints = points.shape[0 ]
167+ let nDims = points.shape[1 ]
168+ let gridSize = calcGridSize (nPoints, nDims, gridSize)
169+
159170 let dataGrid = newRbfGrid (scaledPoints, values, gridSize)
160171 let patchPoints = dataGrid.constructMeshedPatches ()
161172 let nPatches = patchPoints.shape[0 ]
173+
162174 var patchRbfs: seq [RbfBaseType [T]] # = newTensor[RbfBaseType[T]](nPatches, 1)
163175 var patchIndices: seq [int ]
164176 for i in 0 ..< nPatches:
0 commit comments