Skip to content

Commit a8aa0b0

Browse files
authored
Improve hessian sparsity detection (#2882)
1 parent fda0a14 commit a8aa0b0

File tree

4 files changed

+161
-124
lines changed

4 files changed

+161
-124
lines changed

src/Nonlinear/ReverseAD/graph_tools.jl

Lines changed: 153 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,75 @@ function _compute_gradient_sparsity!(
175175
return
176176
end
177177

178+
"""
179+
_get_nonlinear_child_interactions(
180+
node::Nonlinear.Node,
181+
num_children::Int,
182+
)
183+
184+
Get the list of nonlinear child interaction pairs for a node.
185+
Returns empty list of tuples `(i, j)` where `i` and `j` are child indices (1-indexed)
186+
that have nonlinear interactions.
187+
188+
For example, for `*` with 2 children, the result is `[(1, 2)]` because children 1
189+
and 2 interact nonlinearly, but children 1 and 1, or 2 and 2, do not.
190+
191+
For functions like `+` or `-`, the result is `[]` since there are no nonlinear
192+
interactions between children.
193+
"""
194+
function _get_nonlinear_child_interactions(
195+
node::Nonlinear.Node,
196+
num_children::Int,
197+
)::Vector{Tuple{Int,Int}}
198+
if node.type == Nonlinear.NODE_CALL_UNIVARIATE
199+
@assert num_children == 1
200+
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, node.index, nothing)
201+
# Univariate operators :+ and :- don't create interactions
202+
if op in (:+, :-)
203+
return Tuple{Int,Int}[]
204+
else
205+
return [(1, 1)]
206+
end
207+
elseif node.type == Nonlinear.NODE_CALL_MULTIVARIATE
208+
op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, node.index, nothing)
209+
if op in (:+, :-, :ifelse, :min, :max)
210+
# No nonlinear interactions between children
211+
return Tuple{Int,Int}[]
212+
elseif op == :*
213+
# All pairs of distinct children interact nonlinearly
214+
result = Tuple{Int,Int}[]
215+
for i in 1:num_children
216+
for j in 1:(i-1)
217+
push!(result, (j, i))
218+
end
219+
end
220+
return result
221+
elseif op == :/
222+
@assert num_children == 2
223+
# The numerator doesn't have a nonlinear interaction with itself.
224+
return [(1, 2), (2, 2)]
225+
else
226+
# Conservative: assume all pairs interact
227+
result = Tuple{Int,Int}[]
228+
for i in 1:num_children
229+
for j in 1:i
230+
push!(result, (j, i))
231+
end
232+
end
233+
return result
234+
end
235+
else
236+
# Logic and comparison nodes don't generate hessian terms.
237+
# Subexpression nodes are special cased.
238+
return Tuple{Int,Int}[]
239+
end
240+
end
241+
178242
"""
179243
_compute_hessian_sparsity(
180244
nodes::Vector{Nonlinear.Node},
181245
adj,
182246
input_linearity::Vector{Linearity},
183-
indexedset::Coloring.IndexedSet,
184247
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
185248
subexpression_variables::Vector{Vector{Int}},
186249
)
@@ -193,142 +256,118 @@ Compute the sparsity pattern the Hessian of an expression.
193256
* `subexpression_variables` is the list of all variables which appear in a
194257
subexpression (including recursively).
195258
196-
Idea: consider the (non)linearity of a node *with respect to the output*. The
197-
children of any node which is nonlinear with respect to the output should have
198-
nonlinear interactions, hence nonzeros in the hessian. This is not true in
199-
general, but holds for everything we consider.
200-
201-
A counter example is `f(x, y, z) = x + y * z`, but we don't have any functions
202-
like that. By "nonlinear with respect to the output", we mean that the output
203-
depends nonlinearly on the value of the node, regardless of how the node itself
204-
depends on the input.
259+
Returns a `Set{Tuple{Int,Int}}` containing the nonzero entries of the Hessian.
205260
"""
206261
function _compute_hessian_sparsity(
207262
nodes::Vector{Nonlinear.Node},
208263
adj,
209264
input_linearity::Vector{Linearity},
210-
indexedset::Coloring.IndexedSet,
211265
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
212266
subexpression_variables::Vector{Vector{Int}},
213267
)
214-
# So start at the root of the tree and classify the linearity wrt the output.
215-
# For each nonlinear node, do a mini DFS and collect the list of children.
216-
# Add a nonlinear interaction between all children of a nonlinear node.
217268
edge_list = Set{Tuple{Int,Int}}()
218-
nonlinear_wrt_output = fill(false, length(nodes))
219269
children_arr = SparseArrays.rowvals(adj)
220-
stack = Int[]
221-
stack_ignore = Bool[]
222-
nonlinear_group = indexedset
223-
if length(nodes) == 1 && nodes[1].type == Nonlinear.NODE_SUBEXPRESSION
224-
# Subexpression comes in linearly, so append edge_list
225-
for ij in subexpression_edgelist[nodes[1].index]
226-
push!(edge_list, ij)
227-
end
228-
end
229-
for k in 2:length(nodes)
230-
nod = nodes[k]
231-
@assert nod.type != Nonlinear.NODE_MOI_VARIABLE
232-
if nonlinear_wrt_output[k]
233-
continue # already seen this node one way or another
234-
elseif input_linearity[k] == CONSTANT
235-
continue # definitely not nonlinear
270+
# Stack entry: (node_index, child_group_index)
271+
stack = Tuple{Int,Int}[]
272+
# Map from child_group_index to variable indices
273+
child_group_variables = Dict{Int,Set{Int}}()
274+
for (k, node) in enumerate(nodes)
275+
@assert node.type != Nonlinear.NODE_MOI_VARIABLE
276+
if input_linearity[k] == CONSTANT
277+
continue # No hessian contribution from constant nodes
236278
end
237-
@assert !nonlinear_wrt_output[nod.parent]
238-
# check if the parent depends nonlinearly on the value of this node
239-
par = nodes[nod.parent]
240-
if par.type == Nonlinear.NODE_CALL_UNIVARIATE
241-
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, par.index, nothing)
242-
if op === nothing || (op != :+ && op != :-)
243-
nonlinear_wrt_output[k] = true
279+
# Check if this node has nonlinear child interactions
280+
children_idx = SparseArrays.nzrange(adj, k)
281+
num_children = length(children_idx)
282+
interactions = _get_nonlinear_child_interactions(node, num_children)
283+
if !isempty(interactions)
284+
# This node has nonlinear child interactions, so collect variables
285+
# from its children
286+
empty!(child_group_variables)
287+
# DFS from all children, tracking child index
288+
for (child_position, cidx) in enumerate(children_idx)
289+
child_node_idx = children_arr[cidx]
290+
push!(stack, (child_node_idx, child_position))
244291
end
245-
elseif par.type == Nonlinear.NODE_CALL_MULTIVARIATE
246-
op = get(
247-
Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS,
248-
par.index,
249-
nothing,
250-
)
251-
if op === nothing
252-
nonlinear_wrt_output[k] = true
253-
elseif op in (:+, :-, :ifelse)
254-
# pass
255-
elseif op == :*
256-
# check if all siblings are constant
257-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
258-
if !all(
259-
i ->
260-
input_linearity[children_arr[i]] == CONSTANT ||
261-
children_arr[i] == k,
262-
sibling_idx,
263-
)
264-
# at least one sibling isn't constant
265-
nonlinear_wrt_output[k] = true
292+
while length(stack) > 0
293+
r, child_group_idx = pop!(stack)
294+
# Don't traverse into logical conditions or comparisons
295+
if nodes[r].type == Nonlinear.NODE_LOGIC ||
296+
nodes[r].type == Nonlinear.NODE_COMPARISON
297+
continue
266298
end
267-
elseif op == :/
268-
# check if denominator is nonconstant
269-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
270-
if input_linearity[children_arr[last(sibling_idx)]] != CONSTANT
271-
nonlinear_wrt_output[k] = true
299+
r_children_idx = SparseArrays.nzrange(adj, r)
300+
for cidx in r_children_idx
301+
push!(stack, (children_arr[cidx], child_group_idx))
302+
end
303+
if nodes[r].type == Nonlinear.NODE_VARIABLE
304+
if !haskey(child_group_variables, child_group_idx)
305+
child_group_variables[child_group_idx] = Set{Int}()
306+
end
307+
push!(
308+
child_group_variables[child_group_idx],
309+
nodes[r].index,
310+
)
311+
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
312+
sub_vars = subexpression_variables[nodes[r].index]
313+
if !haskey(child_group_variables, child_group_idx)
314+
child_group_variables[child_group_idx] = Set{Int}()
315+
end
316+
union!(child_group_variables[child_group_idx], sub_vars)
272317
end
273-
else
274-
nonlinear_wrt_output[k] = true
275318
end
276-
end
277-
if nod.type == Nonlinear.NODE_SUBEXPRESSION && !nonlinear_wrt_output[k]
278-
# subexpression comes in linearly, so append edge_list
279-
for ij in subexpression_edgelist[nod.index]
319+
_add_hessian_edges!(edge_list, interactions, child_group_variables)
320+
elseif node.type == Nonlinear.NODE_SUBEXPRESSION
321+
for ij in subexpression_edgelist[node.index]
280322
push!(edge_list, ij)
281323
end
282324
end
283-
if !nonlinear_wrt_output[k]
284-
continue
285-
end
286-
# do a DFS from here, including all children
287-
@assert isempty(stack)
288-
@assert isempty(stack_ignore)
289-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
290-
for sidx in sibling_idx
291-
push!(stack, children_arr[sidx])
292-
push!(stack_ignore, false)
293-
end
294-
empty!(nonlinear_group)
295-
while length(stack) > 0
296-
r = pop!(stack)
297-
should_ignore = pop!(stack_ignore)
298-
nonlinear_wrt_output[r] = true
299-
if nodes[r].type == Nonlinear.NODE_LOGIC ||
300-
nodes[r].type == Nonlinear.NODE_COMPARISON
301-
# don't count the nonlinear interactions inside
302-
# logical conditions or comparisons
303-
should_ignore = true
304-
end
305-
children_idx = SparseArrays.nzrange(adj, r)
306-
for cidx in children_idx
307-
push!(stack, children_arr[cidx])
308-
push!(stack_ignore, should_ignore)
309-
end
310-
if should_ignore
311-
continue
312-
end
313-
if nodes[r].type == Nonlinear.NODE_VARIABLE
314-
push!(nonlinear_group, nodes[r].index)
315-
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
316-
# append all variables in subexpression
317-
union!(nonlinear_group, subexpression_variables[nodes[r].index])
325+
end
326+
return edge_list
327+
end
328+
329+
"""
330+
_add_hessian_edges!(
331+
edge_list::Set{Tuple{Int,Int}},
332+
interactions::Vector{Tuple{Int,Int}},
333+
child_variables::Dict{Int,Set{Int}},
334+
)
335+
336+
Add hessian edges based on the operator's nonlinear interaction pattern.
337+
"""
338+
function _add_hessian_edges!(
339+
edge_list::Set{Tuple{Int,Int}},
340+
interactions::Vector{Tuple{Int,Int}},
341+
child_variables::Dict{Int,Set{Int}},
342+
)
343+
for (child_i, child_j) in interactions
344+
if child_i == child_j
345+
# Within-child interactions: add all pairs from a single child
346+
if haskey(child_variables, child_i)
347+
vars = child_variables[child_i]
348+
for vi in vars
349+
for vj in vars
350+
i, j = minmax(vi, vj)
351+
push!(edge_list, (j, i))
352+
end
353+
end
318354
end
319-
end
320-
for i_ in 1:nonlinear_group.nnz
321-
i = nonlinear_group.nzidx[i_]
322-
for j_ in 1:nonlinear_group.nnz
323-
j = nonlinear_group.nzidx[j_]
324-
if j > i
325-
continue # Only lower triangle.
355+
else
356+
# Between-child interactions: add pairs from different children
357+
if haskey(child_variables, child_i) &&
358+
haskey(child_variables, child_j)
359+
vars_i = child_variables[child_i]
360+
vars_j = child_variables[child_j]
361+
for vi in vars_i
362+
for vj in vars_j
363+
i, j = minmax(vi, vj)
364+
push!(edge_list, (j, i))
365+
end
326366
end
327-
push!(edge_list, (i, j))
328367
end
329368
end
330369
end
331-
return edge_list
370+
return
332371
end
333372

334373
"""

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
9393
subex.nodes,
9494
subex.adj,
9595
linearity,
96-
coloring_storage,
9796
subexpression_edgelist,
9897
subexpression_variables,
9998
)

src/Nonlinear/ReverseAD/types.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ struct _FunctionStorage
9191
nodes,
9292
adj,
9393
linearity,
94-
coloring_storage,
9594
subexpression_edgelist,
9695
subexpression_variables,
9796
)

test/Nonlinear/ReverseAD.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,6 @@ function test_linearity()
561561
nodes,
562562
adj,
563563
ret,
564-
indexed_set,
565564
Set{Tuple{Int,Int}}[],
566565
Vector{Int}[],
567566
)
@@ -585,12 +584,7 @@ function test_linearity()
585584
[1, 2],
586585
)
587586
_test_linearity(:(3 * 4 * ($x + $y)), ReverseAD.LINEAR)
588-
_test_linearity(
589-
:($z * $y),
590-
ReverseAD.NONLINEAR,
591-
Set([(3, 2), (3, 3), (2, 2)]),
592-
[2, 3],
593-
)
587+
_test_linearity(:($z * $y), ReverseAD.NONLINEAR, Set([(3, 2)]), [2, 3])
594588
_test_linearity(:(3 + 4), ReverseAD.CONSTANT)
595589
_test_linearity(:(sin(3) + $x), ReverseAD.LINEAR)
596590
_test_linearity(
@@ -635,6 +629,12 @@ function test_linearity()
635629
Set([(1, 1)]),
636630
[1],
637631
)
632+
_test_linearity(
633+
:(($x + $y)/$z),
634+
ReverseAD.NONLINEAR,
635+
Set([(3, 3), (3, 2), (3, 1)]),
636+
[1, 2, 3],
637+
)
638638
return
639639
end
640640

@@ -1416,7 +1416,7 @@ function test_hessian_reinterpret_unsafe()
14161416
x_v = ones(5)
14171417
MOI.eval_hessian_lagrangian(evaluator, H, x_v, 0.0, [1.0, 1.0])
14181418
@test count(isapprox.(H, 1.0; atol = 1e-8)) == 3
1419-
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 6
1419+
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 5
14201420
@test sort(H_s[round.(Bool, H)]) == [(3, 1), (3, 2), (5, 4)]
14211421
return
14221422
end

0 commit comments

Comments
 (0)