Skip to content

Commit cdd2f62

Browse files
authored
optimizer: lift more comparisons (#43227)
This commit implements more comparison liftings. Especially, this change enables the compiler to lift `isa`/`isdefined` checks (i.e. replace a comparison call with ϕ-node by CFG union-splitting). For example, the code snippet below will run 500x faster: ```julia function compute(n) s = 0 itr = 1:n st = iterate(itr) while isdefined(st, 2) # mimic our iteration protocol with `isdefined` v, st = st s += v st = iterate(itr, st) end s end ``` Although it seems like the codegen for `isa` is fairly optimized already and so I could not find any performance benefit for `isa`-lifting (`code_llvm` emits mostly equivalent code), but I hope it's more ideal if we can do the equivalent optimization on Julia level so that we can just consult to `code_typed` for performance optimization.
1 parent 2682819 commit cdd2f62

File tree

2 files changed

+180
-87
lines changed

2 files changed

+180
-87
lines changed

base/compiler/ssair/passes.jl

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -471,19 +471,21 @@ end
471471
make_MaybeUndef(@nospecialize(typ)) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ)
472472

473473
"""
474-
lift_comparison!(compact::IncrementalCompact, idx::Int, stmt::Expr)
474+
lift_comparison!(cmp, compact::IncrementalCompact, idx::Int, stmt::Expr)
475475
476-
Replaces `φ(x, y)::Union{X,Y} === constant` by `φ(x === constant, y === constant)`,
477-
where `x === constant` and `y === constant` can be replaced with constant `Bool`eans.
478-
It helps codegen avoid generating expensive code for `===` with `Union` types.
476+
Replaces `cmp(φ(x, y)::Union{X,Y}, constant)` by `φ(cmp(x, constant), cmp(y, constant))`,
477+
where `cmp(x, constant)` and `cmp(y, constant)` can be replaced with constant `Bool`eans.
478+
It helps codegen avoid generating expensive code for `cmp` with `Union` types.
479479
In particular, this is supposed to improve the performance of the iteration protocol:
480480
```julia
481481
while x !== nothing
482482
x = iterate(...)::Union{Nothing,Tuple{Any,Any}}
483483
end
484484
```
485485
"""
486-
function lift_comparison!(compact::IncrementalCompact,
486+
function lift_comparison! end
487+
488+
function lift_comparison!(::typeof(===), compact::IncrementalCompact,
487489
idx::Int, stmt::Expr, lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue})
488490
args = stmt.args
489491
length(args) == 3 || return
@@ -493,37 +495,59 @@ function lift_comparison!(compact::IncrementalCompact,
493495
vr = argextype(rhs, compact)
494496
if isa(vl, Const)
495497
isa(vr, Const) && return
496-
cmp = vl
497-
typeconstraint = widenconst(vr)
498498
val = rhs
499+
target = lhs
499500
elseif isa(vr, Const)
500-
cmp = vr
501-
typeconstraint = widenconst(vl)
502501
val = lhs
502+
target = rhs
503503
else
504504
return
505505
end
506506

507-
valtyp = widenconst(argextype(val, compact))
508-
isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting
507+
lift_comparison_leaves!(egal_tfunc, compact, val, target, lifting_cache, idx)
508+
end
509509

510-
leaves, visited_phinodes = collect_leaves(compact, val, valtyp)
510+
function lift_comparison!(::typeof(isa), compact::IncrementalCompact,
511+
idx::Int, stmt::Expr, lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue})
512+
args = stmt.args
513+
length(args) == 3 || return
514+
lift_comparison_leaves!(isa_tfunc, compact, args[2], args[3], lifting_cache, idx)
515+
end
516+
517+
function lift_comparison!(::typeof(isdefined), compact::IncrementalCompact,
518+
idx::Int, stmt::Expr, lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue})
519+
args = stmt.args
520+
length(args) == 3 || return
521+
lift_comparison_leaves!(isdefined_tfunc, compact, args[2], args[3], lifting_cache, idx)
522+
end
523+
524+
function lift_comparison_leaves!(@specialize(tfunc),
525+
compact::IncrementalCompact, @nospecialize(val), @nospecialize(target),
526+
lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}, idx::Int)
527+
typeconstraint = widenconst(argextype(val, compact))
528+
if isa(val, Union{OldSSAValue, SSAValue})
529+
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
530+
end
531+
isa(typeconstraint, Union) || return # bail out if there won't be a good chance for lifting
532+
leaves, visited_phinodes = collect_leaves(compact, val, typeconstraint)
511533
length(leaves) 1 && return # bail out if we don't have multiple leaves
512534

513-
# Let's check if we evaluate the comparison for each one of the leaves
535+
# check if we can evaluate the comparison for each one of the leaves
536+
cmp = argextype(target, compact)
514537
lifted_leaves = nothing
515538
for leaf in leaves
516-
r = egal_tfunc(argextype(leaf, compact), cmp)
517-
if isa(r, Const)
539+
result = tfunc(argextype(leaf, compact), cmp)
540+
if isa(result, Const)
518541
if lifted_leaves === nothing
519542
lifted_leaves = LiftedLeaves()
520543
end
521-
lifted_leaves[leaf] = LiftedValue(r.val)
544+
lifted_leaves[leaf] = LiftedValue(result.val)
522545
else
523-
return # TODO In some cases it might be profitable to hoist the === here
546+
return # TODO In some cases it might be profitable to hoist the comparison here
524547
end
525548
end
526549

550+
# perform lifting
527551
lifted_val = perform_lifting!(compact,
528552
visited_phinodes, cmp, lifting_cache, Bool,
529553
lifted_leaves::LiftedLeaves, val)::LiftedValue
@@ -715,10 +739,14 @@ function sroa_pass!(ir::IRCode)
715739
canonicalize_typeassert!(compact, idx, stmt)
716740
continue
717741
elseif is_known_call(stmt, (===), compact)
718-
lift_comparison!(compact, idx, stmt, lifting_cache)
742+
lift_comparison!(===, compact, idx, stmt, lifting_cache)
743+
continue
744+
elseif is_known_call(stmt, isa, compact)
745+
lift_comparison!(isa, compact, idx, stmt, lifting_cache)
746+
continue
747+
elseif is_known_call(stmt, isdefined, compact)
748+
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache)
719749
continue
720-
# elseif is_known_call(stmt, isa, compact)
721-
# TODO do a similar optimization as `lift_comparison!` for `===`
722750
else
723751
continue
724752
end

test/compiler/irpasses.jl

Lines changed: 132 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,34 @@ using Test
44
using Base.Meta
55
using Core: PhiNode, SSAValue, GotoNode, PiNode, QuoteNode, ReturnNode, GotoIfNot
66

7-
# Tests for domsort
7+
# utilities
8+
# =========
9+
10+
import Core.Compiler: argextype, singleton_type
11+
12+
argextype(@nospecialize args...) = argextype(args..., Any[])
13+
code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo
14+
get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code
15+
16+
# check if `x` is a statement with a given `head`
17+
isnew(@nospecialize x) = Meta.isexpr(x, :new)
18+
19+
# check if `x` is a dynamic call of a given function
20+
iscall(y) = @nospecialize(x) -> iscall(y, x)
21+
function iscall((src, f)::Tuple{Core.CodeInfo,Base.Callable}, @nospecialize(x))
22+
return iscall(x) do @nospecialize x
23+
singleton_type(argextype(x, src)) === f
24+
end
25+
end
26+
iscall(pred::Base.Callable, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
27+
28+
# check if `x` is a statically-resolved call of a function whose name is `sym`
29+
isinvoke(y) = @nospecialize(x) -> isinvoke(y, x)
30+
isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x)
31+
isinvoke(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :invoke) && pred(x.args[1]::Core.MethodInstance)
32+
33+
# domsort
34+
# =======
835

936
## Test that domsort doesn't mangle single-argument phis (#29262)
1037
let m = Meta.@lower 1 + 1
@@ -67,25 +94,8 @@ let m = Meta.@lower 1 + 1
6794
Core.Compiler.verify_ir(ir)
6895
end
6996

70-
# Tests for SROA
71-
72-
import Core.Compiler: argextype, singleton_type
73-
const EMPTY_SPTYPES = Any[]
74-
75-
code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo
76-
get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code
77-
78-
# check if `x` is a statement with a given `head`
79-
isnew(@nospecialize x) = Meta.isexpr(x, :new)
80-
81-
# check if `x` is a dynamic call of a given function
82-
iscall(y) = @nospecialize(x) -> iscall(y, x)
83-
function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x))
84-
return iscall(x) do @nospecialize x
85-
singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
86-
end
87-
end
88-
iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
97+
# SROA
98+
# ====
8999

90100
struct ImmutableXYZ; x; y; z; end
91101
mutable struct MutableXYZ; x; y; z; end
@@ -277,6 +287,38 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
277287
@test_broken !any(isnew, src.code)
278288
end
279289

290+
let # should work with constant globals
291+
# immutable case
292+
# --------------
293+
src = @eval Module() begin
294+
const REF_FLD = :x
295+
struct ImmutableRef{T}
296+
x::T
297+
end
298+
299+
code_typed((Int,)) do x
300+
r = ImmutableRef{Int}(x) # should be eliminated
301+
x = getfield(r, REF_FLD) # should be eliminated
302+
return sin(x)
303+
end |> only |> first
304+
end
305+
@test count(iscall((src, getfield)), src.code) == 0
306+
@test count(isnew, src.code) == 0
307+
308+
# mutable case
309+
# ------------
310+
src = @eval Module() begin
311+
const REF_FLD = :x
312+
code_typed() do
313+
r = Ref{Int}(42) # should be eliminated
314+
x = getfield(r, REF_FLD) # should be eliminated
315+
return sin(x)
316+
end |> only |> first
317+
end
318+
@test count(iscall((src, getfield)), src.code) == 0
319+
@test count(isnew, src.code) == 0
320+
end
321+
280322
# should work nicely with inlining to optimize away a complicated case
281323
# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B
282324
struct Point
@@ -296,6 +338,75 @@ let src = code_typed1(compute_points)
296338
@test !any(isnew, src.code)
297339
end
298340

341+
# comparison lifting
342+
# ==================
343+
344+
let # lifting `===`
345+
src = code_typed1((Bool,Int,)) do c, x
346+
y = c ? x : nothing
347+
y === nothing # => ϕ(false, true)
348+
end
349+
@test count(iscall((src, ===)), src.code) == 0
350+
351+
# should optimize away the iteration protocol
352+
src = code_typed1((Int,)) do n
353+
s = 0
354+
for i in 1:n
355+
s += i
356+
end
357+
s
358+
end
359+
@test !any(src.code) do @nospecialize x
360+
iscall((src, ===), x) && argextype(x.args[2], src) isa Union
361+
end
362+
end
363+
364+
let # lifting `isa`
365+
src = code_typed1((Bool,Int,)) do c, x
366+
y = c ? x : nothing
367+
isa(y, Int) # => ϕ(true, false)
368+
end
369+
@test count(iscall((src, isa)), src.code) == 0
370+
371+
src = code_typed1((Int,)) do n
372+
s = 0
373+
itr = 1:n
374+
st = iterate(itr)
375+
while !isa(st, Nothing)
376+
i, st = itr
377+
s += i
378+
st = iterate(itr, st)
379+
end
380+
s
381+
end
382+
@test !any(src.code) do @nospecialize x
383+
iscall((src, isa), x) && argextype(x.args[2], src) isa Union
384+
end
385+
end
386+
387+
let # lifting `isdefined`
388+
src = code_typed1((Bool,Some{Int},)) do c, x
389+
y = c ? x : nothing
390+
isdefined(y, 1) # => ϕ(true, false)
391+
end
392+
@test count(iscall((src, isdefined)), src.code) == 0
393+
394+
src = code_typed1((Int,)) do n
395+
s = 0
396+
itr = 1:n
397+
st = iterate(itr)
398+
while isdefined(st, 2)
399+
i, st = itr
400+
s += i
401+
st = iterate(itr, st)
402+
end
403+
s
404+
end
405+
@test !any(src.code) do @nospecialize x
406+
iscall((src, isdefined), x) && argextype(x.args[2], src) isa Union
407+
end
408+
end
409+
299410
mutable struct Foo30594; x::Float64; end
300411
Base.copy(x::Foo30594) = Foo30594(x.x)
301412
function add!(p::Foo30594, off::Foo30594)
@@ -611,48 +722,6 @@ exc39508 = ErrorException("expected")
611722
end
612723
@test test39508() === exc39508
613724

614-
let # `sroa_pass!` should work with constant globals
615-
# immutable pass
616-
src = @eval Module() begin
617-
const REF_FLD = :x
618-
struct ImmutableRef{T}
619-
x::T
620-
end
621-
622-
code_typed((Int,)) do x
623-
r = ImmutableRef{Int}(x) # should be eliminated
624-
x = getfield(r, REF_FLD) # should be eliminated
625-
return sin(x)
626-
end |> only |> first
627-
end
628-
@test !any(src.code) do @nospecialize(stmt)
629-
Meta.isexpr(stmt, :call) || return false
630-
ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES)
631-
return Core.Compiler.widenconst(ft) == typeof(getfield)
632-
end
633-
@test !any(src.code) do @nospecialize(stmt)
634-
return Meta.isexpr(stmt, :new)
635-
end
636-
637-
# mutable pass
638-
src = @eval Module() begin
639-
const REF_FLD = :x
640-
code_typed() do
641-
r = Ref{Int}(42) # should be eliminated
642-
x = getfield(r, REF_FLD) # should be eliminated
643-
return sin(x)
644-
end |> only |> first
645-
end
646-
@test !any(src.code) do @nospecialize(stmt)
647-
Meta.isexpr(stmt, :call) || return false
648-
ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES)
649-
return Core.Compiler.widenconst(ft) == typeof(getfield)
650-
end
651-
@test !any(src.code) do @nospecialize(stmt)
652-
return Meta.isexpr(stmt, :new)
653-
end
654-
end
655-
656725
let
657726
# `typeassert` elimination after SROA
658727
# NOTE we can remove this optimization once inference is able to reason about memory-effects
@@ -666,11 +735,7 @@ let
666735
end |> only |> first
667736
end
668737
# eliminate `typeassert(x2.x, Foo)`
669-
@test all(src.code) do @nospecialize stmt
670-
Meta.isexpr(stmt, :call) || return true
671-
ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES)
672-
return Core.Compiler.widenconst(ft) !== typeof(typeassert)
673-
end
738+
@test count(iscall((src, typeassert)), src.code) == 0
674739
end
675740

676741
let

0 commit comments

Comments
 (0)