@@ -175,12 +175,75 @@ function _compute_gradient_sparsity!(
175175 return
176176end
177177
"""
    _get_nonlinear_child_interactions(
        node::Nonlinear.Node,
        num_children::Int,
    )

Return the pairs of children of `node` that interact nonlinearly.

Each entry is a tuple `(i, j)` of 1-indexed child positions with `i <= j`. An
entry `(i, i)` means that child `i` has a nonlinear interaction with itself.

 * Univariate calls: `[]` for `+` and `-`, otherwise `[(1, 1)]`.
 * Multivariate `+`, `-`, `ifelse`, `min`, and `max`: `[]`.
 * Multivariate `*`: every pair of distinct children, for example `[(1, 2)]`
   when there are two children.
 * Multivariate `/`: `[(1, 2), (2, 2)]`; the numerator does not interact with
   itself.
 * Any other multivariate operator: every pair, including `(i, i)`, as a
   conservative fallback.
 * Any other node type: `[]`.
"""
function _get_nonlinear_child_interactions(
    node::Nonlinear.Node,
    num_children::Int,
)::Vector{Tuple{Int,Int}}
    if node.type == Nonlinear.NODE_CALL_UNIVARIATE
        @assert num_children == 1
        op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, node.index, nothing)
        # Unary `+` and `-` are linear in their argument; every other
        # univariate operator (including unrecognized ones) is treated as
        # nonlinear in its single child.
        return op in (:+, :-) ? Tuple{Int,Int}[] : [(1, 1)]
    elseif node.type != Nonlinear.NODE_CALL_MULTIVARIATE
        # Logic and comparison nodes don't generate hessian terms, and
        # subexpression nodes are special cased by the caller.
        return Tuple{Int,Int}[]
    end
    op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, node.index, nothing)
    if op in (:+, :-, :ifelse, :min, :max)
        # The children do not interact nonlinearly.
        return Tuple{Int,Int}[]
    elseif op == :*
        # Every pair of distinct children interacts: (j, i) with j < i.
        return Tuple{Int,Int}[
            (j, i) for i in 1:num_children for j in 1:(i-1)
        ]
    elseif op == :/
        @assert num_children == 2
        # The numerator interacts with the denominator, and the denominator
        # with itself, but the numerator does not interact with itself.
        return [(1, 2), (2, 2)]
    end
    # Conservative fallback: assume all pairs interact, including (i, i).
    return Tuple{Int,Int}[(j, i) for i in 1:num_children for j in 1:i]
end
241+
178242"""
179243 _compute_hessian_sparsity(
180244 nodes::Vector{Nonlinear.Node},
181245 adj,
182246 input_linearity::Vector{Linearity},
183- indexedset::Coloring.IndexedSet,
184247 subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
185248 subexpression_variables::Vector{Vector{Int}},
186249 )
@@ -193,142 +256,118 @@ Compute the sparsity pattern the Hessian of an expression.
193256 * `subexpression_variables` is the list of all variables which appear in a
194257 subexpression (including recursively).
195258
196- Idea: consider the (non)linearity of a node *with respect to the output*. The
197- children of any node which is nonlinear with respect to the output should have
198- nonlinear interactions, hence nonzeros in the hessian. This is not true in
199- general, but holds for everything we consider.
200-
201- A counter example is `f(x, y, z) = x + y * z`, but we don't have any functions
202- like that. By "nonlinear with respect to the output", we mean that the output
203- depends nonlinearly on the value of the node, regardless of how the node itself
204- depends on the input.
259+ Returns a `Set{Tuple{Int,Int}}` containing the nonzero entries of the Hessian.
205260"""
function _compute_hessian_sparsity(
    nodes::Vector{Nonlinear.Node},
    adj,
    input_linearity::Vector{Linearity},
    subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
    subexpression_variables::Vector{Vector{Int}},
)
    hessian_edges = Set{Tuple{Int,Int}}()
    child_of = SparseArrays.rowvals(adj)
    # Pending DFS work items: (node index, position of the originating child
    # among the children of the node currently being processed).
    pending = Tuple{Int,Int}[]
    # Variables found beneath each child position of the current node. The
    # dictionary is reused (and emptied) for every node that needs it.
    vars_by_child = Dict{Int,Set{Int}}()
    for (k, node) in enumerate(nodes)
        @assert node.type != Nonlinear.NODE_MOI_VARIABLE
        # Constant nodes make no contribution to the Hessian.
        input_linearity[k] == CONSTANT && continue
        child_range = SparseArrays.nzrange(adj, k)
        interactions =
            _get_nonlinear_child_interactions(node, length(child_range))
        if isempty(interactions)
            # No interactions between children. The only remaining
            # contribution is the stored edge list of a subexpression node.
            if node.type == Nonlinear.NODE_SUBEXPRESSION
                union!(hessian_edges, subexpression_edgelist[node.index])
            end
            continue
        end
        # Collect, per child position, the variables that appear below it.
        empty!(vars_by_child)
        for (position, c) in enumerate(child_range)
            push!(pending, (child_of[c], position))
        end
        while !isempty(pending)
            r, group = pop!(pending)
            visited = nodes[r]
            # Do not descend into logical conditions or comparisons.
            if visited.type == Nonlinear.NODE_LOGIC ||
               visited.type == Nonlinear.NODE_COMPARISON
                continue
            end
            for c in SparseArrays.nzrange(adj, r)
                push!(pending, (child_of[c], group))
            end
            if visited.type == Nonlinear.NODE_VARIABLE
                push!(get!(Set{Int}, vars_by_child, group), visited.index)
            elseif visited.type == Nonlinear.NODE_SUBEXPRESSION
                # A subexpression contributes every variable it contains.
                sub_vars = subexpression_variables[visited.index]
                union!(get!(Set{Int}, vars_by_child, group), sub_vars)
            end
        end
        _add_hessian_edges!(hessian_edges, interactions, vars_by_child)
    end
    return hessian_edges
end
328+
"""
    _add_hessian_edges!(
        edge_list::Set{Tuple{Int,Int}},
        interactions::Vector{Tuple{Int,Int}},
        child_variables::Dict{Int,Set{Int}},
    )

Add to `edge_list` the Hessian entries implied by `interactions`.

For every pair `(child_i, child_j)` in `interactions`, each variable recorded in
`child_variables[child_i]` is paired with each variable recorded in
`child_variables[child_j]`. When `child_i == child_j`, this pairs the variables
of that one child with each other (including with themselves). A pair is
skipped if either child has no entry in `child_variables`.

Every edge is stored as `(larger, smaller)`. Returns `nothing`.
"""
function _add_hessian_edges!(
    edge_list::Set{Tuple{Int,Int}},
    interactions::Vector{Tuple{Int,Int}},
    child_variables::Dict{Int,Set{Int}},
)
    for (child_i, child_j) in interactions
        # A child with no recorded variables cannot produce any edge.
        if !haskey(child_variables, child_i) ||
           !haskey(child_variables, child_j)
            continue
        end
        # When child_i == child_j both sets are the same object, so this
        # covers within-child and between-child interactions alike.
        for vi in child_variables[child_i], vj in child_variables[child_j]
            smaller, larger = minmax(vi, vj)
            push!(edge_list, (larger, smaller))
        end
    end
    return
end
333372
334373"""
0 commit comments