Skip to content

Commit 663438f

Browse files
committed
implement and test general meshgrid
1 parent fbcdea7 commit 663438f

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/numericalnim/utils.nim

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,31 @@ proc meshgridFlat*[T](x, y: Tensor[T]): (Tensor[T], Tensor[T]) =
410410
result[0][i+j*nx] = x[i]
411411
result[1][i+j*nx] = y[j]
412412

413+
proc meshgridInternal[T](x1, x2: Tensor[T]): Tensor[T] =
414+
assert x2.squeeze().shape.len == 1
415+
assert x1.shape.len in [1, 2]
416+
let x1 =
417+
if x1.shape.len == 2:
418+
x1
419+
else:
420+
x1.unsqueeze(1)
421+
let len1 = x1.shape[0]
422+
let cols1 = x1.shape[1]
423+
let len2 = x2.shape[0]
424+
result = newTensor[T](len1 * len2, cols1 + 1)
425+
for i in 0 ..< len2:
426+
result[i*len1 ..< (i+1)*len1, 0 ..< cols1] = x1
427+
result[i*len1 ..< (i+1)*len1, ^1] = x2[i]
428+
429+
proc meshgrid*[T](ts: varargs[Tensor[T]]): Tensor[T] =
430+
if ts.len == 1:
431+
result = ts[0]
432+
elif ts.len == 0:
433+
assert false, "No input was given to meshgrid!"
434+
else:
435+
result = ts[0]
436+
for x in ts[1..^1]:
437+
result = meshgridInternal(result, x)
413438

414439
proc isClose*[T](y1, y2: T, tol: float = 1e-3): bool {.inline.} =
415440
let diff = calcError(y1, y2)

tests/test_utils.nim

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,10 @@ test "meshgridFlat":
8484
check gridX == [0, 1, 2, 0, 1, 2, 0, 1, 2].toTensor
8585
check gridY == [3, 3, 3, 4, 4, 4, 5, 5, 5].toTensor
8686

87+
test "meshgrid":
88+
let x = [0, 1].toTensor
89+
let y = [2, 3].toTensor
90+
let z = [4, 5].toTensor
91+
let grid = meshgrid(x, y, z)
92+
check grid == [[0, 2, 4], [1, 2, 4], [0, 3, 4], [1, 3, 4], [0, 2, 5], [1, 2, 5], [0, 3, 5], [1, 3, 5]].toTensor
93+

0 commit comments

Comments
 (0)