11abstract type AbstractThunk <: AbstractTangent end
22
3+ struct MutateThunkException <: Exception end
4+
5+ function Base. showerror (io:: IO , e:: MutateThunkException )
6+ print (io, " Tried to mutate a thunk, this is not supported. `unthunk` it first." )
7+ return nothing
8+ end
9+
310Base. Broadcast. broadcastable (x:: AbstractThunk ) = broadcastable (unthunk (x))
411
512@inline function Base. iterate (x:: AbstractThunk )
@@ -19,6 +26,94 @@ Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)
1926Base.:(== )(a:: AbstractThunk , b) = unthunk (a) == b
2027Base.:(== )(a, b:: AbstractThunk ) = a == unthunk (b)
2128
29+ Base.:(- )(a:: AbstractThunk ) = - unthunk (a)
30+ Base.:(- )(a:: AbstractThunk , b) = unthunk (a) - b
31+ Base.:(- )(a, b:: AbstractThunk ) = a - unthunk (b)
32+ Base.:(/ )(a:: AbstractThunk , b) = unthunk (a) / b
33+ Base.:(/ )(a, b:: AbstractThunk ) = a / unthunk (b)
34+
35+ Base. real (a:: AbstractThunk ) = real (unthunk (a))
36+ Base. imag (a:: AbstractThunk ) = imag (unthunk (a))
37+ Base. Complex (a:: AbstractThunk ) = Complex (unthunk (a))
38+ Base. Complex (a:: AbstractThunk , b:: AbstractThunk ) = Complex (unthunk (a), unthunk (b))
39+
40+ Base. mapreduce (f, op, a:: AbstractThunk ; kws... ) = mapreduce (f, op, unthunk (a); kws... )
41+ function Base. mapreduce (f, op, itr, a:: AbstractThunk ; kws... )
42+ return mapreduce (f, op, itr, unthunk (a); kws... )
43+ end
44+ Base. sum! (r, A:: AbstractThunk ; kws... ) = sum! (r, unthunk (A); kws... )
45+
46+ Base. vec (a:: AbstractThunk ) = vec (unthunk (a))
47+ Base. reshape (a:: AbstractThunk , args... ) = reshape (unthunk (a), args... )
48+ Base. getindex (a:: AbstractThunk , args... ) = getindex (unthunk (a), args... )
49+ Base. setindex! (a:: AbstractThunk , value, key... ) = throw (MutateThunkException ())
50+ Base. selectdim (a:: AbstractThunk , args... ) = selectdim (unthunk (a), args... )
51+
52+ LinearAlgebra. Array (a:: AbstractThunk ) = Array (unthunk (a))
53+ LinearAlgebra. Matrix (a:: AbstractThunk ) = Matrix (unthunk (a))
54+ LinearAlgebra. Diagonal (a:: AbstractThunk ) = Diagonal (unthunk (a))
55+ LinearAlgebra. LowerTriangular (a:: AbstractThunk ) = LowerTriangular (unthunk (a))
56+ LinearAlgebra. UpperTriangular (a:: AbstractThunk ) = UpperTriangular (unthunk (a))
57+ LinearAlgebra. Symmetric (a:: AbstractThunk , uplo= :U ) = Symmetric (unthunk (a), uplo)
58+ LinearAlgebra. Hermitian (a:: AbstractThunk , uplo= :U ) = Hermitian (unthunk (a), uplo)
59+
60+ function LinearAlgebra. diagm (kv:: Pair{<:Integer,<:AbstractThunk} ...)
61+ return diagm ((k => unthunk (v) for (k, v) in kv). .. )
62+ end
63+ function LinearAlgebra. diagm (m, n, kv:: Pair{<:Integer,<:AbstractThunk} ...)
64+ return diagm (m, n, (k => unthunk (v) for (k, v) in kv). .. )
65+ end
66+ LinearAlgebra. tril (a:: AbstractThunk ) = tril (unthunk (a))
67+ LinearAlgebra. tril (a:: AbstractThunk , k) = tril (unthunk (a), k)
68+ LinearAlgebra. triu (a:: AbstractThunk ) = triu (unthunk (a))
69+ LinearAlgebra. triu (a:: AbstractThunk , k) = triu (unthunk (a), k)
70+ LinearAlgebra. tr (a:: AbstractThunk ) = tr (unthunk (a))
71+ LinearAlgebra. cross (a:: AbstractThunk , b) = cross (unthunk (a), b)
72+ LinearAlgebra. cross (a, b:: AbstractThunk ) = cross (a, unthunk (b))
73+ LinearAlgebra. cross (a:: AbstractThunk , b:: AbstractThunk ) = cross (unthunk (a), unthunk (b))
74+ LinearAlgebra. dot (a:: AbstractThunk , b) = dot (unthunk (a), b)
75+ LinearAlgebra. dot (a, b:: AbstractThunk ) = dot (a, unthunk (b))
76+ LinearAlgebra. dot (a:: AbstractThunk , b:: AbstractThunk ) = dot (unthunk (a), unthunk (b))
77+
78+ LinearAlgebra. ldiv! (a, b:: AbstractThunk ) = throw (MutateThunkException ())
79+ LinearAlgebra. rdiv! (a:: AbstractThunk , b) = throw (MutateThunkException ())
80+
81+ LinearAlgebra. mul! (C:: AbstractThunk , A, B, α, β) = throw (MutateThunkException ())
82+ function LinearAlgebra. mul! (C:: AbstractThunk , A:: AbstractThunk , B, α, β)
83+ return throw (MutateThunkException ())
84+ end
85+ function LinearAlgebra. mul! (C:: AbstractThunk , A, B:: AbstractThunk , α, β)
86+ return throw (MutateThunkException ())
87+ end
88+ function LinearAlgebra. mul! (C:: AbstractThunk , A:: AbstractThunk , B:: AbstractThunk , α, β)
89+ return throw (MutateThunkException ())
90+ end
91+ LinearAlgebra. mul! (C, A:: AbstractThunk , B, α, β) = mul! (C, unthunk (A), B, α, β)
92+ LinearAlgebra. mul! (C, A, B:: AbstractThunk , α, β) = mul! (C, A, unthunk (B), α, β)
93+ function LinearAlgebra. mul! (C, A:: AbstractThunk , B:: AbstractThunk , α, β)
94+ return mul! (C, unthunk (A), unthunk (B), α, β)
95+ end
96+
97+ function LinearAlgebra. BLAS. ger! (alpha, x:: AbstractThunk , y, A)
98+ return LinearAlgebra. BLAS. ger! (alpha, unthunk (x), y, A)
99+ end
100+ function LinearAlgebra. BLAS. ger! (alpha, x, y:: AbstractThunk , A)
101+ return LinearAlgebra. BLAS. ger! (alpha, x, unthunk (y), A)
102+ end
103+ function LinearAlgebra. BLAS. gemv! (tA, alpha, A, x:: AbstractThunk , beta, y)
104+ return LinearAlgebra. BLAS. gemv! (tA, alpha, A, unthunk (x), beta, y)
105+ end
106+ function LinearAlgebra. BLAS. gemv (tA, alpha, A, x:: AbstractThunk )
107+ return LinearAlgebra. BLAS. gemv (tA, alpha, A, unthunk (x))
108+ end
109+ function LinearAlgebra. BLAS. scal! (n, a:: AbstractThunk , X, incx)
110+ return LinearAlgebra. BLAS. scal! (n, unthunk (a), X, incx)
111+ end
112+
113+ function LinearAlgebra. LAPACK. trsyl! (transa, transb, A, B, C:: AbstractThunk , isgn= 1 )
114+ return throw (MutateThunkException ())
115+ end
116+
22117"""
23118 @thunk expr
24119
@@ -109,7 +204,7 @@ but it should do this more efficently than simply doing this directly.
109204Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
110205and destroy its inplacability.
111206"""
112- struct InplaceableThunk{T<: Thunk , F} <: AbstractThunk
207+ struct InplaceableThunk{T<: Thunk ,F} <: AbstractThunk
113208 val:: T
114209 add!:: F
115210end
@@ -118,5 +213,5 @@ unthunk(x::InplaceableThunk) = unthunk(x.val)
118213(x:: InplaceableThunk )() = unthunk (x)
119214
120215function Base. show (io:: IO , x:: InplaceableThunk )
121- print (io, " InplaceableThunk($(repr (x. val)) , $(repr (x. add!)) )" )
216+ return print (io, " InplaceableThunk($(repr (x. val)) , $(repr (x. add!)) )" )
122217end
0 commit comments