From 6ec2575b525704869984b942877c4b032bb8d4c4 Mon Sep 17 00:00:00 2001 From: gszep Date: Mon, 4 Nov 2019 09:44:19 +0000 Subject: [PATCH 1/6] scalar arithmetic working. vector not working --- src/Tracker.jl | 5 +- src/lib/array.jl | 5 +- src/lib/complex.jl | 111 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 3 deletions(-) create mode 100644 src/lib/complex.jl diff --git a/src/Tracker.jl b/src/Tracker.jl index adceea61..abee665c 100644 --- a/src/Tracker.jl +++ b/src/Tracker.jl @@ -66,6 +66,7 @@ include("params.jl") include("back.jl") include("numeric.jl") include("lib/real.jl") +include("lib/complex.jl") include("lib/array.jl") include("forward.jl") @@ -100,11 +101,13 @@ nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs) @grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f") @grad nobacksies(f::String, x) = data(x), Δ -> error(f) -param(x::Number) = TrackedReal(float(x)) +param(x::Real) = TrackedReal(float(x)) +param(x::Complex) = TrackedComplex(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) @grad identity(x) = data(x), Δ -> (Δ,) param(x::TrackedReal) = track(identity, x) +param(x::TrackedComplex) = track(identity, x) param(x::TrackedArray) = track(identity, x) import Adapt: adapt, adapt_structure diff --git a/src/lib/array.jl b/src/lib/array.jl index e5a36d30..a2d0912b 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -32,6 +32,7 @@ TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x)) Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} +Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Complex = TrackedComplex{T} Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x @@ -171,7 +172,7 @@ end for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat] cnames = map(_ -> gensym(), c) - @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) = + @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal,TrackedComplex}, xs::Union{AbstractArray,Number}...) = track($f, $(cnames...), x, xs...) end @@ -522,7 +523,7 @@ using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted struct TrackedStyle <: BroadcastStyle end -Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle() +Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal,TrackedComplex}}) = TrackedStyle() Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle() # We have to re-build the original broadcast struct to get the appropriate array diff --git a/src/lib/complex.jl b/src/lib/complex.jl new file mode 100644 index 00000000..caf2c22f --- /dev/null +++ b/src/lib/complex.jl @@ -0,0 +1,111 @@ +mutable struct TrackedComplex{T<:Complex} <: Real + data::T + tracker::Tracked{T} +end + +TrackedComplex(x::Complex) = TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x))) +TrackedComplex(x::Real) = TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x))) + +data(x::TrackedComplex) = x.data +tracker(x::TrackedComplex) = x.tracker + +track(f::Call, x::Complex) = TrackedComplex(x, Tracked{typeof(x)}(f, zero(x))) + +function back!(x::TrackedComplex; once = true) + isinf(x) && error("Loss is Inf") + isnan(x) && error("Loss is NaN") + return back!(x, 1, once = once) +end + +function update!(x::TrackedComplex, Δ) + x.data += data(Δ) + tracker(x).grad = 0 + return x +end + +function Base.show(io::IO, x::TrackedComplex) + T = get(io, :typeinfo, Any) + show(io, data(x)) + T <: TrackedComplex || print(io, " (tracked)") +end + +Base.decompose(x::TrackedComplex) = Base.decompose(data(x)) + +Base.copy(x::TrackedComplex) = x + +Base.convert(::Type{TrackedComplex{T}}, x::TrackedComplex{T}) where T = x + +Base.convert(::Type{TrackedComplex{T}}, x::Complex) where T = TrackedComplex(convert(T, x)) +Base.convert(::Type{TrackedComplex{T}}, x::Real) where T = TrackedComplex(convert(T, x)) + +Base.convert(::Type{TrackedComplex{T}}, x::TrackedComplex{S}) where {T,S} = + error("Not implemented: convert tracked $S to tracked $T") + +(T::Type{<:TrackedComplex})(x::Complex) = convert(T, x) + +for op in [:(==), :≈, :<, :(<=)] + @eval Base.$op(x::TrackedComplex, y::Complex) = Base.$op(data(x), y) + @eval Base.$op(x::Complex, y::TrackedComplex) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedComplex, y::TrackedComplex) = Base.$op(data(x), data(y)) +end + +Base.eps(x::TrackedComplex) = eps(data(x)) +Base.eps(::Type{TrackedComplex{T}}) where T = eps(T) + +for f in :[isinf, isnan, isfinite].args + @eval Base.$f(x::TrackedComplex) = Base.$f(data(x)) +end + +Base.Printf.fix_dec(x::TrackedComplex, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...) + +Base.float(x::TrackedComplex) = x + +Base.promote_rule(::Type{TrackedComplex{S}},::Type{T}) where {S,T} = + TrackedComplex{promote_type(S,T)} + +using Random + +for f in :[rand, randn, randexp].args + @eval Random.$f(rng::AbstractRNG,::Type{TrackedComplex{T}}) where {T} = param(rand(rng,T)) +end + +using DiffRules, SpecialFunctions, NaNMath + +for (M, f, arity) in DiffRules.diffrules() + arity == 1 || continue + @eval begin + @grad $M.$f(a::Complex) = + $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),) + $M.$f(a::TrackedComplex) = track($M.$f, a) + end +end + + +for (M, f, arity) in DiffRules.diffrules() + arity == 2 || continue + da, db = DiffRules.diffrule(M, f, :a, :b) + f = :($M.$f) + @eval begin + @grad $f(a::TrackedComplex, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + + @grad $f(a::TrackedComplex, b::Complex) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) + @grad $f(a::TrackedComplex, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) + + @grad $f(a::Complex, b::TrackedComplex) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) + @grad $f(a::Real, b::TrackedComplex) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) + + $f(a::TrackedComplex, b::TrackedComplex) = track($f, a, b) + + $f(a::TrackedComplex, b::Complex) = track($f, a, b) + $f(a::TrackedComplex, b::Real) = track($f, a, b) + + $f(a::Complex, b::TrackedComplex) = track($f, a, b) + $f(a::Real, b::TrackedComplex) = track($f, a, b) + end +end + +using ForwardDiff: Dual +import Base:^ + +^(a::TrackedComplex, b::Integer) = track(^, a, b) +(T::Type{<:Complex})(x::Dual) = Dual(T(x.value), map(T, x.partials.values)) \ No newline at end of file From fda6895342d394a62a6433be542e867e70ac510e Mon Sep 17 00:00:00 2001 From: gszep Date: Mon, 4 Nov 2019 14:14:50 +0000 Subject: [PATCH 2/6] wip : vector operations still not working --- src/lib/complex.jl | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/lib/complex.jl b/src/lib/complex.jl index caf2c22f..1c3fdb58 100644 --- a/src/lib/complex.jl +++ b/src/lib/complex.jl @@ -1,10 +1,9 @@ -mutable struct TrackedComplex{T<:Complex} <: Real +mutable struct TrackedComplex{T<:Complex} # <: AbstractComplex data::T tracker::Tracked{T} end TrackedComplex(x::Complex) = TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x))) -TrackedComplex(x::Real) = TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x))) data(x::TrackedComplex) = x.data tracker(x::TrackedComplex) = x.tracker @@ -35,17 +34,16 @@ Base.copy(x::TrackedComplex) = x Base.convert(::Type{TrackedComplex{T}}, x::TrackedComplex{T}) where T = x -Base.convert(::Type{TrackedComplex{T}}, x::Complex) where T = TrackedComplex(convert(T, x)) -Base.convert(::Type{TrackedComplex{T}}, x::Real) where T = TrackedComplex(convert(T, x)) +Base.convert(::Type{TrackedComplex{T}}, x::Union{Complex,Real}) where T = TrackedComplex(convert(T, x)) Base.convert(::Type{TrackedComplex{T}}, x::TrackedComplex{S}) where {T,S} = error("Not implemented: convert tracked $S to tracked $T") -(T::Type{<:TrackedComplex})(x::Complex) = convert(T, x) +(T::Type{<:TrackedComplex})(x::Union{Complex,Real}) = convert(T, x) for op in [:(==), :≈, :<, :(<=)] - @eval Base.$op(x::TrackedComplex, y::Complex) = Base.$op(data(x), y) - @eval Base.$op(x::Complex, y::TrackedComplex) = Base.$op(x, data(y)) + @eval Base.$op(x::TrackedComplex, y::Union{Complex,Real}) = Base.$op(data(x), y) + @eval Base.$op(x::Union{Complex,Real}, y::TrackedComplex) = Base.$op(x, data(y)) @eval Base.$op(x::TrackedComplex, y::TrackedComplex) = Base.$op(data(x), data(y)) end @@ -80,32 +78,34 @@ for (M, f, arity) in DiffRules.diffrules() end end - for (M, f, arity) in DiffRules.diffrules() arity == 2 || continue da, db = DiffRules.diffrule(M, f, :a, :b) f = :($M.$f) @eval begin - @grad $f(a::TrackedComplex, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - - @grad $f(a::TrackedComplex, b::Complex) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) - @grad $f(a::TrackedComplex, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) - @grad $f(a::Complex, b::TrackedComplex) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) - @grad $f(a::Real, b::TrackedComplex) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) + @grad $f(a::TrackedComplex, b::Union{TrackedComplex,TrackedReal}) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::Union{TrackedComplex,TrackedReal}, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - $f(a::TrackedComplex, b::TrackedComplex) = track($f, a, b) + @grad $f(a::TrackedComplex, b::Union{Complex,Real}) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) + @grad $f(a::Union{Complex,Real}, b::TrackedComplex) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) - $f(a::TrackedComplex, b::Complex) = track($f, a, b) - $f(a::TrackedComplex, b::Real) = track($f, a, b) + $f(a::TrackedComplex, b::Union{TrackedComplex,TrackedReal}) = track($f, a, b) + $f(a::Union{TrackedComplex,TrackedReal}, b::TrackedComplex) = track($f, a, b) - $f(a::Complex, b::TrackedComplex) = track($f, a, b) - $f(a::Real, b::TrackedComplex) = track($f, a, b) + $f(a::TrackedComplex, b::Union{Complex,Real}) = track($f, a, b) + $f(a::Union{Complex,Real}, b::TrackedComplex) = track($f, a, b) end end -using ForwardDiff: Dual +# Eliminating ambiguity, Hack for conversions import Base:^ +using ForwardDiff: Dual ^(a::TrackedComplex, b::Integer) = track(^, a, b) -(T::Type{<:Complex})(x::Dual) = Dual(T(x.value), map(T, x.partials.values)) \ No newline at end of file +(T::Type{<:Complex})(x::Dual) = Dual(T(x.value), map(T, x.partials.values)) + +# Array collection + +collectmemaybe(xs::AbstractArray{>:TrackedComplex}) = collect(xs) +collectmemaybe(xs::AbstractArray{<:TrackedComplex}) = collect(xs) \ No newline at end of file From 85d48ae214310d24b0a44c3dd80a268e44940f13 Mon Sep 17 00:00:00 2001 From: gszep Date: Mon, 4 Nov 2019 14:57:39 +0000 Subject: [PATCH 3/6] big fix : elementary functions on complex scalars --- src/lib/complex.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/complex.jl b/src/lib/complex.jl index 1c3fdb58..85b06ddd 100644 --- a/src/lib/complex.jl +++ b/src/lib/complex.jl @@ -72,9 +72,9 @@ using DiffRules, SpecialFunctions, NaNMath for (M, f, arity) in DiffRules.diffrules() arity == 1 || continue @eval begin - @grad $M.$f(a::Complex) = + @grad $M.$f(a::Union{Complex,TrackedComplex}) = $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),) - $M.$f(a::TrackedComplex) = track($M.$f, a) + $M.$f(a::Union{Complex,TrackedComplex}) = track($M.$f, a) end end From a2ffb86ff15c116077b734ce8e85ed4d0ec3eaec Mon Sep 17 00:00:00 2001 From: gszep Date: Mon, 4 Nov 2019 16:46:12 +0000 Subject: [PATCH 4/6] fix : nontracked Stackoverflow --- src/lib/complex.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/lib/complex.jl b/src/lib/complex.jl index 85b06ddd..5d712752 100644 --- a/src/lib/complex.jl +++ b/src/lib/complex.jl @@ -72,9 +72,9 @@ using DiffRules, SpecialFunctions, NaNMath for (M, f, arity) in DiffRules.diffrules() arity == 1 || continue @eval begin - @grad $M.$f(a::Union{Complex,TrackedComplex}) = + @grad $M.$f(a::TrackedComplex) = $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),) - $M.$f(a::Union{Complex,TrackedComplex}) = track($M.$f, a) + $M.$f(a::TrackedComplex) = track($M.$f, a) end end @@ -84,14 +84,16 @@ for (M, f, arity) in DiffRules.diffrules() f = :($M.$f) @eval begin - @grad $f(a::TrackedComplex, b::Union{TrackedComplex,TrackedReal}) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - @grad $f(a::Union{TrackedComplex,TrackedReal}, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedComplex, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedComplex, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) @grad $f(a::TrackedComplex, b::Union{Complex,Real}) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) @grad $f(a::Union{Complex,Real}, b::TrackedComplex) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) - $f(a::TrackedComplex, b::Union{TrackedComplex,TrackedReal}) = track($f, a, b) - $f(a::Union{TrackedComplex,TrackedReal}, b::TrackedComplex) = track($f, a, b) + $f(a::TrackedComplex, b::TrackedComplex) = track($f, a, b) + $f(a::TrackedComplex, b::TrackedReal) = track($f, a, b) + $f(a::TrackedReal, b::TrackedComplex) = track($f, a, b) $f(a::TrackedComplex, b::Union{Complex,Real}) = track($f, a, b) $f(a::Union{Complex,Real}, b::TrackedComplex) = track($f, a, b) @@ -103,7 +105,7 @@ import Base:^ using ForwardDiff: Dual ^(a::TrackedComplex, b::Integer) = track(^, a, b) -(T::Type{<:Complex})(x::Dual) = Dual(T(x.value), map(T, x.partials.values)) +(T::Type{<:Union{Complex,TrackedComplex}})(x::Dual) = Dual(T(x.value), map(T, x.partials.values)) # Array collection From 1b1e1eb888a740ea23b32ed25716c1c3c3d959f8 Mon Sep 17 00:00:00 2001 From: gszep Date: Tue, 5 Nov 2019 02:12:07 +0000 Subject: [PATCH 5/6] bug fix : tracking in element-wise ops --- .gitignore | 2 ++ src/lib/array.jl | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index eb18605c..b4dea73a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ docs/build/ docs/site/ deps +.* +!.gitignore \ No newline at end of file diff --git a/src/lib/array.jl b/src/lib/array.jl index a2d0912b..f6e24780 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -499,7 +499,7 @@ unbroadcast(x::Number, Δ) = sum(Δ) unbroadcast(x::Base.RefValue, _) = nothing dual(x, p) = x -dual(x::Real, p) = Dual(x, p) +dual(x::Union{Real,Complex}, p) = Dual(x, p) function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N} dargs = ntuple(j -> dual(args[j], i==j), Val(N)) @@ -508,7 +508,7 @@ end @inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N} y = broadcast(f, data.(args)...) - eltype(y) <: Real || return y + eltype(y) <: Union{Real,Complex} || return y eltype(y) == Bool && return y function back(Δ) Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N)) From 8af8985c0c15e405654c87a8839feaed3609b050 Mon Sep 17 00:00:00 2001 From: gszep Date: Wed, 13 Nov 2019 13:26:52 +0000 Subject: [PATCH 6/6] wip : gradients of f : C(N) -> R(1) not working --- src/forward.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forward.jl b/src/forward.jl index ccf75c70..d23fa93a 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -1,6 +1,6 @@ using ForwardDiff -seed(x::Real, ::Val) = Dual(x, true) +seed(x::Union{Real,Complex}, ::Val) = Dual(x, true) function seed(x, ::Val{N}, offset = 0) where N map(x, reshape(1:length(x), size(x))) do x, i