-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Description
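The following code runs fine in plain Julia but fails under `@compile`. Tracing apparently cannot rebuild `MyStruct`, because its field type parameter is constrained by another type parameter (`VT<:AbstractVector{T}`). The traced type substituted for `VT` is rejected by that bound: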
julia> struct MyStruct{T, VT<:AbstractVector{T}}
data::VT
end
julia> test(x::MyStruct) = sum(x.data)
test (generic function with 1 method)
julia> x = MyStruct(rand(10))
MyStruct{Float64, Vector{Float64}}([0.9356372116439775, 0.8146232142268686, 0.8646666386255512, 0.15223146438404445, 0.5357206445516378, 0.12556851518402812, 0.014343067486348171, 0.6245919019686935, 0.9909782907741145, 0.3270593445135602])
julia> test(x)
5.385420293358824
julia> dx = MyStruct(Reactant.ConcreteRArray(x.data))
MyStruct{Float64, ConcretePJRTArray{Float64, 1, 1}}(ConcretePJRTArray{Float64, 1, 1}([0.9356372116439775, 0.8146232142268686, 0.8646666386255512, 0.15223146438404445, 0.5357206445516378, 0.12556851518402812, 0.014343067486348171, 0.6245919019686935, 0.9909782907741145, 0.3270593445135602]))
julia> test_compiled = @compile test(dx)
ERROR: TypeError: in MyStruct, in VT, expected VT<:AbstractVector{Float64}, got Type{Reactant.TracedRArray{Float64, 1}}
Stacktrace:
[1] traced_type_inner(T::Type, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type, sharding::Any, runtime::Any)
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/Tracing.jl:750
[2] traced_type(T::Type, ::Val{Reactant.ConcreteToTraced}, track_numbers::Type, sharding::Reactant.Sharding.NoSharding, runtime::Val{:PJRT})
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/Tracing.jl:886
[3] make_tracer_unknown(seen::Reactant.OrderedIdDict{…}, prev::Any, path::Any, mode::Reactant.TraceMode; track_numbers::Type, sharding::Any, runtime::Any, kwargs::@Kwargs{…})
@ Reactant ~/.julia/packages/Reactant/8rzTQ/src/Tracing.jl:1048
[4] make_tracer_unknown
@ ~/.julia/packages/Reactant/8rzTQ/src/Tracing.jl:1025 [inlined]
[5] #make_tracer#137
@ ~/.julia/packages/Reactant/8rzTQ/src/Tracing.jl:1162 [inlined]
[6] make_tracer
@ ~/.julia/packages/Reactant/8rzTQ/src/Tracing.jl:1152 [inlined]
[7] prepare_mlir_fn_args(args::Tuple{…}, name::String, concretein::Bool, toscalar::Bool, argprefix::Symbol, runtime::Val{…}, optimize_then_pad::Bool, do_transpose::Bool, input_shardings::Nothing, verify_arg_names::Nothing)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/8rzTQ/src/TracedUtils.jl:450
[8] make_mlir_fn(f::typeof(test), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/8rzTQ/src/TracedUtils.jl:321
[9] make_mlir_fn
@ ~/.julia/packages/Reactant/8rzTQ/src/TracedUtils.jl:275 [inlined]
[10] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(test), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:1605
[11] compile_mlir!
@ ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:1567 [inlined]
[12] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:3513
[13] compile_xla
@ ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:3485 [inlined]
[14] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:3589
[15] top-level scope
@ ~/.julia/packages/Reactant/8rzTQ/src/Compiler.jl:2658
Some type information was truncated. Use `show(err)` to see complete types.
Metadata
Assignees
Labels
No labels