Skip to content

Commit 3875b9d

Browse files
Merge pull request #1686 from JuliaSymbolics/as/more-stuff
fix: fix some method ambiguties, improve testing
2 parents 541c4dc + 8e3458c commit 3875b9d

File tree

3 files changed

+102
-7
lines changed

3 files changed

+102
-7
lines changed

src/array-lib.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function wrapper_fn_from_idxs(x::Arr{T, N}, idxs...) where {T, N}
1111
nd = _indexed_ndims(idxs...)
1212
return nd == 0 ? is_wrapper_type(T) ? T : identity : Arr{T, nd}
1313
end
14+
wrapper_fn_from_idxs(x::Arr{T, N}, idx::SymbolicUtils.StableIndex{Int}) where {T, N} = T
1415
# Wrapped array should wrap the elements too
1516
function Base.getindex(x::Arr{T, N}, idx::CartesianIndex{N}) where {T, N}
1617
if is_wrapper_type(T)
@@ -80,10 +81,10 @@ end
8081
function *(a::PolyadicT, b::Arr, bs::PolyadicT...)
8182
return *(a, unwrap(b), bs...)
8283
end
83-
function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, bs::PolyadicT...) where {T}
84+
function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, bs::PolyadicT...) where {T <: Number}
8485
return *(a, unwrap(b), bs...)
8586
end
86-
function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, c::AbstractVector, bs::PolyadicT...) where {T}
87+
function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, c::AbstractVector, bs::PolyadicT...) where {T <: Number}
8788
return *(a, unwrap(b), unwrap(c), bs...)
8889
end
8990
function *(a::Number, b::Arr, bs::PolyadicT...)
@@ -121,6 +122,20 @@ function +(x1::Arr, x2::Arr, args::AbstractArray...)
121122
return +(unwrap(x1), unwrap(x2), args...)
122123
end
123124

125+
for T1 in [Arr, AbstractArray], T2 in [Arr, AbstractArray]
126+
T1 == T2 == AbstractArray && continue
127+
@eval Base.:(\)(x1::$T1{Num, 1}, x2::$T2{Num, 1}) = Num(unwrap(x1) \ unwrap(x2))
128+
@eval Base.:(\)(x1::$T1{Num, 1}, x2::$T2{Num, 2}) = Arr{Num, 2}(unwrap(x1) \ unwrap(x2))
129+
@eval Base.:(\)(x1::$T1{Num, 2}, x2::$T2{Num, 1}) = Arr{Num, 1}(unwrap(x1) \ unwrap(x2))
130+
@eval Base.:(\)(x1::$T1{Num, 2}, x2::$T2{Num, 2}) = Arr{Num, 2}(unwrap(x1) \ unwrap(x2))
131+
132+
@eval Base.:(/)(x1::$T1{Num, 1}, x2::$T2{Num, 1}) = Arr{Num, 2}(unwrap(x1) / unwrap(x2))
133+
@eval Base.:(/)(x1::$T1{Num, 1}, x2::$T2{Num, 2}) = Arr{Num, 2}(unwrap(x1) / unwrap(x2))
134+
@eval Base.:(/)(x1::$T1{Num, 2}, x2::$T2{Num, 2}) = Arr{Num, 2}(unwrap(x1) / unwrap(x2))
135+
end
136+
137+
Base.:(/)(x1::Num, x2::Arr{Num, 1}) = Arr{Num, 2}(unwrap(x1) / unwrap(x2))
138+
124139
#################### MAP-REDUCE ################
125140

126141
SymbolicUtils.@map_methods Arr unwrap wrap

src/num.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ for (T1, T2) in Iterators.product([Num, Integer], [Num, Integer])
6565
end
6666
end
6767

68-
for f in [/, \, ^]
68+
for f in [\, ^]
6969
@eval function (::$(typeof(f)))(x1::AbstractArray{<:Real}, x2::Num)
7070
$f(x1, unwrap(x2))
7171
end
@@ -75,6 +75,14 @@ for f in [/, \, ^]
7575
end
7676
end
7777

78+
function Base.:(/)(x1::AbstractArray{<:Real}, x2::Num)
79+
/(unwrap(x1), unwrap(x2))
80+
end
81+
82+
function Base.:(/)(x1::Num, x2::AbstractVector{<:Real})
83+
/(unwrap(x1), unwrap(x2))
84+
end
85+
7886
Base.conj(x::Num) = x
7987
Base.transpose(x::Num) = x
8088

test/arrays.jl

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ using SymbolicUtils, Test
33
using Symbolics: symtype, shape, wrap, unwrap, Arr, jacobian, @variables, value, get_variables, @arrayop, getname, metadata, scalarize
44
using Base: Slice
55
using SymbolicUtils: Sym, term, operation, search_variables
6-
import LinearAlgebra: dot, Adjoint
6+
import SymbolicUtils.Code: toexpr
7+
import LinearAlgebra: dot, Adjoint, cross, diagm
78
import ..limit2
89

910
struct TestMetaT end
@@ -446,7 +447,78 @@ end
446447
@test operation(sym) === +
447448
end
448449

449-
@testset "matmul ambiguity" begin
450-
@variables x[1:3] y
451-
@test_nowarn x * y
450+
@testset "Basic functions test" begin
451+
@variables a b c
452+
@variables x[1:3] y[1:3, 1:3]
453+
arrays = [
454+
x,
455+
y,
456+
unwrap(x),
457+
unwrap(y),
458+
collect(y),
459+
[a, b, c],
460+
diagm([a, b, c]) * x,
461+
diagm([a, b, c]) * [a, b, c],
462+
y * [a, b, c],
463+
diagm([a, b, c]) * y,
464+
]
465+
466+
xval = rand(3)
467+
yval = rand(3, 3)
468+
setup(ex) = quote
469+
let a = 1.0, b = 2.3, c = 4.5, x = $xval, y = $yval
470+
$(ex)
471+
end
472+
end
473+
474+
@testset "$arr1 AND $arr2" for arr1 in arrays, arr2 in arrays
475+
val1 = eval(setup(toexpr(arr1)))
476+
val2 = eval(setup(toexpr(arr2)))
477+
if ndims(arr1) == ndims(arr2) == 1 && length(arr1) == length(arr2) == 3
478+
t1 = setup(toexpr(cross(arr1, arr2)))
479+
@test eval(t1) cross(val1, val2)
480+
end
481+
if size(arr1) == size(arr2)
482+
t1 = setup(toexpr(dot(arr1, arr2)))
483+
@test eval(t1) dot(val1, val2)
484+
485+
t1 = setup(toexpr(arr1 .* arr2))
486+
@test eval(t1) val1 .* val2
487+
488+
t1 = setup(toexpr(arr1 + arr2))
489+
@test eval(t1) val1 + val2
490+
end
491+
t1 = try
492+
setup(toexpr(arr1'arr2))
493+
catch
494+
nothing
495+
end
496+
if t1 !== nothing
497+
@test eval(t1) val1'val2
498+
end
499+
500+
truth = try
501+
val1 / val2
502+
catch
503+
nothing
504+
end
505+
if truth !== nothing
506+
t1 = setup(toexpr(arr1 / arr2))
507+
@test eval(t1) truth
508+
end
509+
510+
truth = try
511+
val1 \ val2
512+
catch
513+
nothing
514+
end
515+
t1 = try
516+
setup(toexpr(arr1 \ arr2))
517+
catch
518+
nothing
519+
end
520+
if truth !== nothing && t1 !== nothing
521+
@test eval(t1) truth
522+
end
523+
end
452524
end

0 commit comments

Comments
 (0)