60 changes: 59 additions & 1 deletion pytensor/tensor/rewriting/basic.py
@@ -30,7 +30,7 @@
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph, Op
from pytensor.graph.basic import Constant
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter,
@@ -910,6 +910,64 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
"""Join(axis, x, x, x, ...) -> repeat(x, n, axis)

    When the same tensor is concatenated multiple times along an axis
    where it has a static size of 1, replace the join with a single
    repeat operation, which is more efficient.

Examples
--------
concatenate([x[None], x[None], x[None]], axis=0) -> repeat(x[None], 3, axis=0)
"""
# Extract axis and the tensors being joined
axis_sym, *tensors = node.inputs

# Need at least 2 tensors to consider optimization
if len(tensors) <= 1:
return None

# Extract (and normalize) axis as Python int
try:
axis_val = int(get_scalar_constant_value(axis_sym, only_process_constants=True))
except NotScalarConstantError:
return None

# Get first tensor and check if ndim is known
first = tensors[0]
ndim = first.ndim
if ndim is None:
return None

# Normalize negative axes (e.g., -1 -> ndim-1)
axis_val = axis_val % ndim

# All inputs must be structurally the same tensor
# Use equal_computations to check structural equality, not symbolic ==
for t in tensors[1:]:
if not equal_computations([t], [first]):
return None

# Only apply when size along join axis is statically 1
# (e.g., x[None] has a guaranteed 1 at that axis)
shp = first.type.shape # tuple of ints/None
if shp is None or axis_val >= len(shp) or shp[axis_val] != 1:
return None

# Replace with repeat operation
from pytensor.tensor.extra_ops import repeat

n = len(tensors)
result = repeat(first, n, axis=axis_val)
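    # Note: with a constant repeat count along a length-1 axis, later
    # rewrites typically lower this repeat to an Alloc (broadcast).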

# Preserve debugging information
copy_stack_trace(node.outputs[0], result)

return [result]


@register_specialize
@register_canonicalize
@register_useless
100 changes: 100 additions & 0 deletions tests/tensor/rewriting/test_basic.py
@@ -1247,6 +1247,106 @@ def test_local_join_1():
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_repeat():
"""Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)

This optimization applies when joining the same tensor multiple times
along an axis where it has size 1 (e.g., after ExpandDims).
"""

# Test with vector expanded to (1, n) - concatenate along axis 0
x = vector("x")
x_expanded = x[None] # Shape: (1, n)
s = join(0, x_expanded, x_expanded, x_expanded) # Shape: (3, n)
f = function([x], s, mode=rewrite_mode)

# Check numerical correctness
test_val = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
result = f(test_val)
expected = np.array(
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=config.floatX
)
assert np.allclose(result, expected)

# Check that Join was replaced with Alloc (repeat with scalar repeats becomes Alloc)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - add dimension and concatenate along new axis
a = matrix("a") # Shape: (m, n)
a_expanded = a[None, :, :] # Shape: (1, m, n)
s = join(0, a_expanded, a_expanded, a_expanded, a_expanded) # Shape: (4, m, n)
f = function([a], s, mode=rewrite_mode)

test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
result = f(test_mat)
expected = np.array([test_mat, test_mat, test_mat, test_mat])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test with matrix - expand along axis 1 and concatenate
a_expanded_ax1 = a[:, None, :] # Shape: (m, 1, n)
s = join(1, a_expanded_ax1, a_expanded_ax1) # Shape: (m, 2, n)
f = function([a], s, mode=rewrite_mode)

result = f(test_mat)
expected = np.array([[[1.0, 2.0], [1.0, 2.0]], [[3.0, 4.0], [3.0, 4.0]]])
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1

# Test that it does NOT apply when tensors are different
y = vector("y")
s = join(0, x[None], y[None])
f = function([x, y], s, mode=rewrite_mode)

test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
result = f(test_vec1, test_vec2)
expected = np.array([[1.0, 2.0], [3.0, 4.0]])
assert np.allclose(result, expected)

# Join should still be present (not optimized)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1

# Test that it does NOT apply when tensor doesn't have size 1 along join axis
# (regular concatenation without ExpandDims)
s = join(0, x, x, x) # Shape: (3n,) not using ExpandDims
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
assert np.allclose(result, expected)

# Join should still be present (optimization doesn't apply)
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 1

# Test with 5 repetitions to ensure it works with larger counts
s = join(0, x[None], x[None], x[None], x[None], x[None])
f = function([x], s, mode=rewrite_mode)

test_val = np.array([1.0, 2.0], dtype=config.floatX)
result = f(test_val)
expected = np.array([[1.0, 2.0]] * 5, dtype=config.floatX)
assert np.allclose(result, expected)

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Alloc)]) >= 1


def test_local_join_empty():
# Vector case
empty_vec = np.asarray([], dtype=config.floatX)