Scans are never constant-folded #1688

@ricardoV94

Description

import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x")
xs, _ = pytensor.scan(lambda x: x+1, outputs_info=[x0], n_steps=4)
fn = pytensor.function([x0], xs)
fn.dprint()  # Scan still in the graph

This happens because Alloc is never constant-folded when it feeds a SetSubtensor: most of the time we want to write into the allocated buffer in place, and we can't write in place into a constant. But when the whole chain could ultimately be constant-folded (as here), this is wasteful.

The decision of whether to constant-fold based on the surrounding graph should be the responsibility of the constant-fold rewrite, not the Op. Right now it is implemented in each Op's do_constant_folding method.

AllocEmpty never constant-folds:

def do_constant_folding(self, fgraph, node):
    return False

And Alloc has this logic in it:

def do_constant_folding(self, fgraph, node):
    clients = fgraph.clients[node.outputs[0]]
    if not clients:
        return False
    for client, idx in clients:
        client_op = client.op
        if isinstance(client_op, Output):
            # If the output is a constant, it will have to be deepcopied
            # each time the function is called. So we do not fold.
            return False
        # Op's through which Alloc can be lifted
        elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
            return False
        # Same for Blockwise, unless it has no batch_dims
        elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
            return False
        elif (
            # The following ops work inplace of their input id 0.
            idx == 0
            and isinstance(
                client_op,
                pytensor.tensor.subtensor.IncSubtensor
                | pytensor.tensor.subtensor.AdvancedIncSubtensor1
                | pytensor.tensor.subtensor.AdvancedIncSubtensor
                | pytensor.tensor.blas.Gemv
                | pytensor.tensor.blas_c.CGemv
                | pytensor.tensor.blas.Ger
                | pytensor.tensor.blas_c.CGer,
            )
        ):
            # Ops that will work inplace on the Alloc. So if they
            # get constant_folded, they would copy the constant
            # and this is less efficient.
            # Not doing the constant folding could also lower the
            # peak memory use, as the "constant" won't always exist.
            return False
    return True
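
To make the proposal concrete, here is a rough, library-free sketch of what a graph-aware decision could look like if it lived in the rewrite instead of in the Op. Everything in it is hypothetical: `Node`, `build_clients`, `is_constant_subgraph` and `should_fold` are toy names for illustration, not PyTensor APIs, and the op names are plain string labels. The idea is only that an Alloc-like buffer gets folded when every client that would write into it in place is itself computable from constants, so the whole chain collapses.

```python
from dataclasses import dataclass

# Toy labels only; these strings stand in for real ops and are an assumption.
INPLACE_ON_FIRST_INPUT = {"IncSubtensor"}


@dataclass(eq=False)  # eq=False keeps identity-based hashing
class Node:
    op: str
    inputs: tuple = ()


def build_clients(outputs):
    """Map each node to the (client, input_index) pairs that consume it."""
    clients, seen, stack = {}, set(), list(outputs)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        clients.setdefault(node, [])
        for idx, inp in enumerate(node.inputs):
            clients.setdefault(inp, []).append((node, idx))
            stack.append(inp)
    return clients


def is_constant_subgraph(node):
    """True if the node depends only on constants, never on a runtime input."""
    if node.op == "Constant":
        return True
    if node.op == "Input":
        return False
    return all(is_constant_subgraph(inp) for inp in node.inputs)


def should_fold(node, clients):
    """Graph-level policy owned by the (toy) rewrite rather than by the op."""
    if not all(inp.op == "Constant" for inp in node.inputs):
        return False
    for client, idx in clients.get(node, []):
        writes_inplace = idx == 0 and client.op in INPLACE_ON_FIRST_INPUT
        if not writes_inplace:
            continue
        # Keep the buffer only if the in-place write needs runtime data;
        # otherwise the client can be folded too and no copy is ever made.
        others = [inp for j, inp in enumerate(client.inputs) if j != idx]
        if not all(is_constant_subgraph(other) for other in others):
            return False
    return True


const = Node("Constant")
buf_a = Node("Alloc", (const, const))
write_a = Node("IncSubtensor", (buf_a, const))          # written value is constant
buf_b = Node("Alloc", (const, const))
write_b = Node("IncSubtensor", (buf_b, Node("Input")))  # written value is a runtime input

clients = build_clients([write_a, write_b])
print(should_fold(buf_a, clients))  # True
print(should_fold(buf_b, clients))  # False
```

The sketch deliberately ignores the other cases handled by the current Alloc logic (outputs, ops through which Alloc can be lifted, Blockwise); a real rewrite would still need a policy for those.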
