Skip to content

Commit a8f419a

Browse files
committed
Fix thermalfluid benchmark & other misc fixes
1 parent e7a97eb commit a8f419a

File tree

4 files changed

+22
-26
lines changed

4 files changed

+22
-26
lines changed

src/analysis/cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function add_equation_row!(graph, solvable_graph, ieq::Int, inc::Incidence)
6060
for (v, coeff) in zip(rowvals(inc.row), nonzeros(inc.row))
6161
v == 1 && continue
6262
add_edge!(graph, ieq, v-1)
63-
coeff !== nonlinear && add_edge!(solvable_graph, ieq, v-1)
63+
isa(coeff, Float64) && add_edge!(solvable_graph, ieq, v-1)
6464
end
6565
end
6666
add_equation_row!(graph, solvable_graph, ieq::Int, c::Const) = nothing

src/analysis/lattice.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ const nonlinear = Linearity()
7676

7777
join_linearity(a::Linearity, b::Real) = a
7878
join_linearity(a::Real, b::Linearity) = b
79-
join_linearity(a::Real, b::Real) = a == b ? promote(a, b) : Linearity(time_dependent = false, state_dependent = false, nonlinear = false)
79+
join_linearity(a::Real, b::Real) = a == b ? a : linear
8080
function join_linearity(a::Linearity, b::Linearity)
8181
(a.nonlinear | b.nonlinear) && return nonlinear
8282
return Linearity(; time_dependent = a.time_dependent | b.time_dependent, state_dependent = a.state_dependent | b.state_dependent, nonlinear = false)
@@ -698,11 +698,7 @@ function Compiler.tmerge(🥬::EqStructureLattice, @nospecialize(a), @nospeciali
698698
merged_typ = Compiler.tmerge(Compiler.widenlattice(🥬), a.typ, b.typ)
699699
row = _zero_row()
700700
for i in union(rowvals(a.row), rowvals(b.row))
701-
if a.row[i] == b.row[i]
702-
row[i] = a.row[i]
703-
else
704-
row[i] = nonlinear
705-
end
701+
row[i] = join_linearity(a.row[i], b.row[i])
706702
end
707703
return Incidence(merged_typ, row)
708704
elseif isa(b, Const)

src/transform/autodiff/index_lowering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function index_lowering_ad!(state::TransformationState, key::TornCacheKey)
7272
# Mark all non-trivial `ddt()` statements as ones that we should differentiate
7373
diff_ssas = Pair{SSAValue,Int}[]
7474
for i = 1:length(ir.stmts)
75-
if is_known_invoke(ir.stmts[i][:stmt], ddt, ir) && !is_const_plus_var_linear(argextype(ir.stmts[i][:stmt].args[end], ir))
75+
if is_known_invoke(ir.stmts[i][:stmt], ddt, ir) && !is_const_plus_var_known_linear(argextype(ir.stmts[i][:stmt].args[end], ir))
7676
push!(diff_ssas, SSAValue(i) => 0)
7777
end
7878
end

src/transform/tearing/schedule.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ Base.setindex!(rno::RenameOverlayVector, @nospecialize(val), i::Int) =
121121
Base.setindex!(rno.overlay, val, i)
122122
Base.size(rno::RenameOverlayVector) = Base.size(rno.base)
123123

124-
is_var_part_linear(incT::Const) = true
125-
is_var_part_linear(incT::Incidence) = !any(==(nonlinear), incT.row)
126-
is_const_plus_var_linear(incT) = is_var_part_linear(incT) && isa(incT.typ, Const)
127-
is_const_plus_var_linear(incT::Const) = true
124+
is_var_part_known_linear(incT::Const) = true
125+
is_var_part_known_linear(incT::Incidence) = all(x -> isa(x, Float64), incT.row)
126+
is_const_plus_var_known_linear(incT) = is_var_part_known_linear(incT) && isa(incT.typ, Const)
127+
is_const_plus_var_known_linear(incT::Const) = true
128128

129-
is_fully_state_linear(incT, param_vars) = is_const_plus_var_linear(incT) && is_fully_state_linear(incT.typ, param_vars)
129+
is_fully_state_linear(incT, param_vars) = is_const_plus_var_known_linear(incT) && is_fully_state_linear(incT.typ, param_vars)
130130
is_fully_state_linear(incT::Const, param_vars) = iszero(incT.val)
131131

132132
function schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, ordinal, val::Union{SSAValue, Argument}, ssa_rename::AbstractVector{Any}; vars, schedule_missing_var! = nothing)
@@ -182,7 +182,7 @@ function schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, ordinal,
182182
end
183183

184184
# TODO: SICM
185-
if !is_const_plus_var_linear(typ::Incidence)
185+
if !is_const_plus_var_known_linear(typ::Incidence)
186186
this_nonlinear = schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, ordinal, arg, ssa_rename; vars, schedule_missing_var!)
187187
else
188188
if @isdefined(result)
@@ -202,7 +202,7 @@ function schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, ordinal,
202202
return schedule_incidence!(compact, this_nonlinear, typ, -1, inst[:line]; vars, schedule_missing_var!)[1]
203203
end
204204

205-
if is_const_plus_var_linear(incT)
205+
if is_const_plus_var_known_linear(incT)
206206
ret = schedule_incidence!(compact, nothing, info.result.extended_rt, -1, inst[:line]; vars=
207207
[arg === nothing ? 0.0 : arg for arg in args], schedule_missing_var! = var->error((var, incT, args)))[1]
208208
else
@@ -329,7 +329,7 @@ function compute_eq_schedule(key::TornCacheKey, total_incidence, result, mss::St
329329
i in this_callee_eqs && continue # We already scheduled this
330330
callee_incidence = callee_info.result.total_incidence[i]
331331
incidence = apply_linear_incidence(nothing, callee_incidence, nothing, callee_info.mapping)
332-
if is_const_plus_var_linear(incidence)
332+
if is_const_plus_var_known_linear(incidence)
333333
# No non-linear components - skip it
334334
push!(previously_scheduled_or_ignored, i)
335335
continue
@@ -364,7 +364,7 @@ function compute_eq_schedule(key::TornCacheKey, total_incidence, result, mss::St
364364
end
365365

366366
function schedule_eq!(eq; force=false)
367-
if is_const_plus_var_linear(total_incidence[eq])
367+
if is_const_plus_var_known_linear(total_incidence[eq])
368368
# This is a linear equation, we can schedule it now
369369
return true
370370
end
@@ -382,7 +382,7 @@ function compute_eq_schedule(key::TornCacheKey, total_incidence, result, mss::St
382382
schedule = get!(callee_schedules, ssa, Vector{Pair{BitSet, BitSet}}())
383383
callee_info = result.ir[SSAValue(ssa.id)][:info]::MappingInfo
384384

385-
if is_const_plus_var_linear(callee_info.result.total_incidence[callee_eq])
385+
if is_const_plus_var_known_linear(callee_info.result.total_incidence[callee_eq])
386386
# This portion of the calle is linear, we can schedule it
387387
continue
388388
end
@@ -626,7 +626,7 @@ function matching_for_key(state::TransformationState, key::TornCacheKey)
626626
invview(var_eq_matching)[eq] = WrongEquation()
627627
continue
628628
end
629-
if is_const_plus_var_linear(total_incidence[eq]) && invview(var_eq_matching)[eq] === unassigned
629+
if is_const_plus_var_known_linear(total_incidence[eq]) && invview(var_eq_matching)[eq] === unassigned
630630
invview(var_eq_matching)[eq] = FullyLinear()
631631
end
632632
end
@@ -787,8 +787,8 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
787787
push!(callee_explicit_eqs, callee_eq)
788788
else
789789
var = invview(var_eq_matching)[caller_eq]
790-
(is_const_plus_var_linear(callee_result.total_incidence[callee_eq]) ||
791-
is_const_plus_var_linear(total_incidence[caller_eq])) && continue
790+
(is_const_plus_var_known_linear(callee_result.total_incidence[callee_eq]) ||
791+
is_const_plus_var_known_linear(total_incidence[caller_eq])) && continue
792792
@assert var !== unassigned
793793
if !any(out->callee_eq in out[2], callee_var_schedule)
794794
display(mss)
@@ -847,7 +847,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
847847
for var in callee_param_vars
848848
varmap = info.mapping.var_coeffs[var]
849849
nonlin = nothing
850-
if !is_const_plus_var_linear(varmap)
850+
if !is_const_plus_var_known_linear(varmap)
851851
nonlin = schedule_nonlinear!(compact, key.param_vars, var_eq_matching, ir, ordinal, stmt.args[1+var], sicm_rename; vars=var_sols)
852852
end
853853
(argval, _) = schedule_incidence!(compact,
@@ -901,7 +901,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
901901
result.eq_callee_mapping[baseeq(result, structure, eq)] === nothing &&
902902
eqclassification(result, structure, eq) != External &&
903903
eqkind(result, structure, eq) == Intrinsics.Always &&
904-
!is_const_plus_var_linear(total_incidence[eq])
904+
!is_const_plus_var_known_linear(total_incidence[eq])
905905
push!(eq_orders[end], eq)
906906
end
907907
end
@@ -1000,7 +1000,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
10001000
for var in callee_in_vars
10011001
nonlin = nothing
10021002
varmap = info.mapping.var_coeffs[var]
1003-
if !is_const_plus_var_linear(varmap)
1003+
if !is_const_plus_var_known_linear(varmap)
10041004
nonlin = schedule_nonlinear!(compact1, key.param_vars, var_eq_matching, ir, ordinal, eqinst[:stmt].args[1+var], ssa_rename; vars=var_sols, schedule_missing_var!)
10051005
end
10061006
(argval, _) = schedule_incidence!(compact1,
@@ -1044,7 +1044,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
10441044
var = invview(var_eq_matching)[eq]
10451045

10461046
incT = state.total_incidence[eq]
1047-
anynonlinear = !is_const_plus_var_linear(incT)
1047+
anynonlinear = !is_const_plus_var_known_linear(incT)
10481048
nonlinearssa = nothing
10491049
if anynonlinear
10501050
if isa(var, Int) && isa(vars[var], SolvedVariable)
@@ -1086,7 +1086,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
10861086
if isa(var, Int)
10871087
curval = nonlinearssa
10881088
(curval, thiscoeff) = schedule_incidence!(compact1, curval, incT, var, line; vars=var_sols, schedule_missing_var!)
1089-
@assert thiscoeff != nonlinear
1089+
@assert isa(thiscoeff, Float64)
10901090
curval = ir_mul_const!(compact1, line, 1/thiscoeff, curval)
10911091
var_sols[var] = isa(curval, SSAValue) ? CarriedSSAValue(ordinal, curval.id) : curval
10921092
insert_solved_var_here!(compact1, var, curval, line)

0 commit comments

Comments
 (0)