Incompatibility with Zygote #268

@daniloefl

Dear ArrayFire developers,

It seems that ArrayFire.jl has a small compatibility issue with Zygote.

A test example follows in [1], but the core of the issue is that Zygote implements a function, `Zygote.accum`, which sums gradients; for `AbstractArray`s it is defined as shown in [2]. That method broadcasts `accum` over its arguments, expecting each element-wise call to reach the generic (non-`AbstractArray`) method. ArrayFire's broadcasting, however, calls the function on the whole arrays rather than element by element (see `broadcasted(::Function, ::AFArray{Float32,1}, ::AFArray{Float32,1})` in the stack trace below). So `accum` is called again with the same `AFArray` arguments, dispatches to the `AbstractArray` method again, and loops until the stack overflows.

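A simple fix is to define a method of this function specifically for `AFArray`:

```julia
using ArrayFire, Zygote

# A method specific to AFArray takes precedence over
# accum(::AbstractArray, ::AbstractArray), so the recursive broadcast is never reached.
Zygote.accum(x::AFArray, y::AFArray) =
    x === nothing ? y :
    y === nothing ? x :
    x .+ y
```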

With this method defined, the example in [1] works. I am not sure whether other methods are needed in more general cases, though. The Zygote developers could be pinged here, but fixing it on their side would create a dependency between Zygote and ArrayFire, which is not really necessary. I am not aware of a cleaner way of solving the issue.
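The dispatch cycle, and why a more specific method breaks it, can be reproduced without ArrayFire or Zygote. This is only an illustration under stated assumptions: `FakeAF` is a hypothetical wrapper type whose two-argument broadcasting mimics what the stack trace shows ArrayFire doing (passing the whole arrays to the function), and `accum` is a local copy of Zygote's generic methods from [2], not Zygote itself.

```julia
# Hypothetical stand-in for AFArray (NOT ArrayFire): its two-argument
# broadcasting hands the whole arrays to `f`, as the stack trace shows
# ArrayFire's broadcasted(::Function, ::AFArray, ::AFArray) doing.
struct FakeAF <: AbstractVector{Float32}
    data::Vector{Float32}
end
Base.size(a::FakeAF) = size(a.data)
Base.getindex(a::FakeAF, i::Int) = a.data[i]
Base.Broadcast.broadcasted(f, a::FakeAF, b::FakeAF) = f(a, b)

# Local copies of Zygote's generic accum methods (see [2]).
accum(x, y) =
    x === nothing ? y :
    y === nothing ? x :
    x + y
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

a = FakeAF([1f0, 2f0])
b = FakeAF([3f0, 4f0])

# accum(a, b) -> accum.(a, b) -> broadcasted(accum, a, b) -> accum(a, b) -> ...
overflowed = try
    accum(a, b)
    false
catch err
    err isa StackOverflowError
end

# The workaround: a method specific to the wrapper type wins dispatch.
# Summing the wrapped data stands in for ArrayFire's on-device x .+ y.
accum(x::FakeAF, y::FakeAF) = FakeAF(x.data .+ y.data)

fixed = accum(a, b)
```

Because Julia selects the most specific applicable method, the `FakeAF` method shadows the `AbstractArray` one only for that type; ordinary arrays keep the generic element-wise behavior.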

Best regards,
Danilo

[1]
Test example:

```julia
using ArrayFire
using Flux
using DiffEqFlux
using Zygote

hyper = FastChain(FastDense(1, 10, tanh), FastDense(10, 10, tanh), FastDense(10, 16, tanh))
p = initial_params(hyper)
x = rand(Float32, 1, 100)

# Required because of a separate indexing issue in DiffEqFlux, unrelated to this bug.
# Without this override the code crashes due to another incompatibility
# (which I daresay is an issue in DiffEqFlux):
DiffEqFlux.applychain(fs::Tuple, x, p) =
    DiffEqFlux.applychain(Base.tail(fs),
                          first(fs)(x, p[1:DiffEqFlux.paramlength(first(fs))]),
                          length(fs) > 1 ? p[(DiffEqFlux.paramlength(first(fs))+1):end] : Tuple{}())

af_p = AFArray(p)

# This works:
hyper(x, af_p)

# This does not (StackOverflowError):
gs = Flux.gradient(params(af_p)) do
    sum(hyper(x, af_p))
end
```

The error I get is:

```
julia> gs = Flux.gradient(params(af_p)) do
                sum(hyper(x, af_p))
                end
ERROR: StackOverflowError:
Stacktrace:
 [1] broadcasted(::Function, ::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/ArrayFire/U0hth/src/array.jl:217
 [2] accum(::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/lib/lib.jl:16
 [3] broadcasted(::Function, ::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/ArrayFire/U0hth/src/array.jl:220
 ... (the last 2 lines are repeated 16335 more times)
 [32674] accum(::AFArray{Float32,1}, ::AFArray{Float32,1}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/lib/lib.jl:16
 [32675] applychain at ./REPL[7]:2 [inlined]
 [32676] (::typeof(∂(applychain)))(::FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 ... (the last 2 lines are repeated 1 more time)
 [32679] FastChain at /home/daniloefl/.julia/packages/DiffEqFlux/8UHw5/src/fast_layers.jl:21 [inlined]
 [32680] (::typeof(∂(λ)))(::FillArrays.Fill{Float32,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [32681] #5 at ./REPL[16]:2 [inlined]
 [32682] (::typeof(∂(#5)))(::Float32) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [32683] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(#5))})(::Float32) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
 [32684] gradient(::Function, ::Zygote.Params) at /home/daniloefl/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54
```

[2]
From Zygote.jl/src/lib/lib.jl:

```julia
accum() = nothing
accum(x) = x

accum(x, y) =
  x === nothing ? y :
  y === nothing ? x :
  x + y

accum(x, y, zs...) = accum(accum(x, y), zs...)

accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)
```
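For ordinary arrays the last method does what Zygote intends: `accum.(x, y)` runs element by element, each inner call reaches the generic `accum(x, y)` method, and `nothing` entries (no gradient) are passed through instead of being added. A minimal check, using plain copies of the methods above defined outside Zygote purely for illustration:

```julia
# Local copies of the Zygote methods listed above (illustration only, not Zygote).
accum() = nothing
accum(x) = x
accum(x, y) =
    x === nothing ? y :
    y === nothing ? x :
    x + y
accum(x, y, zs...) = accum(accum(x, y), zs...)
accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

# Plain arrays: the broadcast is element-wise, so every inner call
# dispatches to the scalar accum(x, y) method.
a = accum([1.0, 2.0], [3.0, 4.0])          # [4.0, 6.0]

# A `nothing` element is passed through instead of being added.
b = accum([nothing, 1.0], [2.0, nothing])  # [2.0, 1.0]

# Three or more gradients are folded pairwise.
c = accum([1.0], [2.0], [3.0])             # [6.0]

# Tuples are accumulated element-wise as well.
d = accum((1, nothing), (2, 3))            # (3, 3)
```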

