Skip to content

Commit 8e3458c

Browse files
test: more extensive testing for basic array operations
1 parent a43835a commit 8e3458c

File tree

1 file changed

+76
-4
lines changed

1 file changed

+76
-4
lines changed

test/arrays.jl

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ using SymbolicUtils, Test
33
using Symbolics: symtype, shape, wrap, unwrap, Arr, jacobian, @variables, value, get_variables, @arrayop, getname, metadata, scalarize
44
using Base: Slice
55
using SymbolicUtils: Sym, term, operation, search_variables
6-
import LinearAlgebra: dot, Adjoint
6+
import SymbolicUtils.Code: toexpr
7+
import LinearAlgebra: dot, Adjoint, cross, diagm
78
import ..limit2
89

910
struct TestMetaT end
@@ -446,7 +447,78 @@ end
446447
@test operation(sym) === +
447448
end
448449

449-
@testset "matmul ambiguity" begin
450-
@variables x[1:3] y
451-
@test_nowarn x * y
450+
@testset "Basic functions test" begin
451+
@variables a b c
452+
@variables x[1:3] y[1:3, 1:3]
453+
arrays = [
454+
x,
455+
y,
456+
unwrap(x),
457+
unwrap(y),
458+
collect(y),
459+
[a, b, c],
460+
diagm([a, b, c]) * x,
461+
diagm([a, b, c]) * [a, b, c],
462+
y * [a, b, c],
463+
diagm([a, b, c]) * y,
464+
]
465+
466+
xval = rand(3)
467+
yval = rand(3, 3)
468+
setup(ex) = quote
469+
let a = 1.0, b = 2.3, c = 4.5, x = $xval, y = $yval
470+
$(ex)
471+
end
472+
end
473+
474+
@testset "$arr1 AND $arr2" for arr1 in arrays, arr2 in arrays
475+
val1 = eval(setup(toexpr(arr1)))
476+
val2 = eval(setup(toexpr(arr2)))
477+
if ndims(arr1) == ndims(arr2) == 1 && length(arr1) == length(arr2) == 3
478+
t1 = setup(toexpr(cross(arr1, arr2)))
479+
@test eval(t1) cross(val1, val2)
480+
end
481+
if size(arr1) == size(arr2)
482+
t1 = setup(toexpr(dot(arr1, arr2)))
483+
@test eval(t1) dot(val1, val2)
484+
485+
t1 = setup(toexpr(arr1 .* arr2))
486+
@test eval(t1) val1 .* val2
487+
488+
t1 = setup(toexpr(arr1 + arr2))
489+
@test eval(t1) val1 + val2
490+
end
491+
t1 = try
492+
setup(toexpr(arr1'arr2))
493+
catch
494+
nothing
495+
end
496+
if t1 !== nothing
497+
@test eval(t1) val1'val2
498+
end
499+
500+
truth = try
501+
val1 / val2
502+
catch
503+
nothing
504+
end
505+
if truth !== nothing
506+
t1 = setup(toexpr(arr1 / arr2))
507+
@test eval(t1) truth
508+
end
509+
510+
truth = try
511+
val1 \ val2
512+
catch
513+
nothing
514+
end
515+
t1 = try
516+
setup(toexpr(arr1 \ arr2))
517+
catch
518+
nothing
519+
end
520+
if truth !== nothing && t1 !== nothing
521+
@test eval(t1) truth
522+
end
523+
end
452524
end

0 commit comments

Comments
 (0)