Skip to content
34 changes: 33 additions & 1 deletion pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.extra_ops import broadcast_arrays, repeat
from pytensor.tensor.math import Sum, add, eq, variadic_add
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.type import DenseTensorType, TensorType
Expand Down Expand Up @@ -910,6 +910,38 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
    """Rewrite ``Join(axis, x, x, ..., x)`` into a single repeat-based graph.

    ``concatenate([x] * n, axis)`` lays down *whole copies* of ``x`` along
    ``axis`` (``[1, 2, 1, 2]``).  Note that this is **not** the same as
    ``repeat(x, n, axis)``, which repeats each element in place
    (``[1, 1, 2, 2]``).  The correct equivalent graph is::

        repeat(expand_dims(x, axis), n, axis).reshape(out_shape)

    i.e. a broadcastable dimension is inserted at ``axis``, repeated ``n``
    times, and then merged back into ``axis`` by the reshape, which yields
    the whole-copy (tile) ordering that Join produces.

    The rewrite only fires when the join axis is a compile-time constant,
    since the dimension bookkeeping below needs a static axis.

    Examples
    --------
    concatenate([x, x, x], axis=0) -> repeat(x[None], 3, 0).reshape(...)
    """
    axis, *tensors = node.inputs

    # Need at least two operands for the rewrite to be worthwhile.
    if len(tensors) <= 1:
        return None

    # All operands must be the very same variable.  Identity (`is`) is the
    # intended check here, not `==`, which does not express "same graph
    # node" as clearly and is fragile if operator overloading changes.
    first = tensors[0]
    if not all(t is first for t in tensors[1:]):
        return None

    # The join axis must be a static (constant) integer; bail out otherwise.
    # NOTE(review): local import keeps this diff-chunk edit self-contained —
    # hoist to the module import block when merging.
    from pytensor.tensor.basic import get_scalar_constant_value

    try:
        static_axis = int(get_scalar_constant_value(axis))
    except NotScalarConstantError:
        return None

    ndim = first.type.ndim
    # Normalize negative axes so the dimshuffle/reshape indexing works.
    static_axis = static_axis % ndim

    n_reps = len(tensors)

    # Insert a broadcastable dimension right at `static_axis` ...
    pattern = list(range(ndim))
    pattern.insert(static_axis, "x")
    expanded = first.dimshuffle(pattern)

    # ... repeat along it, giving shape (..., n_reps, d_axis, ...) ...
    repeated = repeat(expanded, n_reps, static_axis)

    # ... then merge the new dimension back into `static_axis`, producing
    # the whole-copy ordering of the original Join.
    out_shape = [first.shape[i] for i in range(ndim)]
    out_shape[static_axis] = n_reps * out_shape[static_axis]
    result = repeated.reshape(out_shape)

    # Preserve debugging information from the original Join output.
    copy_stack_trace(node.outputs[0], result)

    return [result]


@register_specialize
@register_canonicalize
@register_useless
Expand Down
81 changes: 81 additions & 0 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
tile,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.extra_ops import Repeat
from pytensor.tensor.math import (
add,
bitwise_and,
Expand Down Expand Up @@ -1247,6 +1248,86 @@ def test_local_join_1():
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_repeat():
    """Test that Join(axis, x, x, ...) gets rewritten to repeat(x, n, axis)"""

    def n_ops(fn, op_type):
        # Count nodes of the given Op type in the compiled graph.
        return sum(isinstance(node.op, op_type) for node in fn.maker.fgraph.toposort())

    x = vector("x")
    a = matrix("a")
    vec3 = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
    mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)

    # Same vector concatenated 3 times along axis 0: numerically a tile,
    # and the Join must have been replaced by a single Repeat.
    f = function([x], join(0, x, x, x), mode=rewrite_mode)
    assert np.allclose(f(vec3), np.tile(vec3, 3))
    assert n_ops(f, Join) == 0
    assert n_ops(f, Repeat) == 1

    # Same matrix concatenated 4 times along axis 0.
    f = function([a], join(0, a, a, a, a), mode=rewrite_mode)
    assert np.allclose(f(mat), np.vstack([mat] * 4))
    assert n_ops(f, Join) == 0
    assert n_ops(f, Repeat) == 1

    # Same matrix concatenated twice along axis 1.
    f = function([a], join(1, a, a), mode=rewrite_mode)
    assert np.allclose(f(mat), np.hstack([mat, mat]))
    assert n_ops(f, Join) == 0
    assert n_ops(f, Repeat) == 1

    # Distinct tensors: the rewrite must NOT fire, Join stays in the graph.
    b = matrix("b")
    f = function([a, b], join(0, a, b), mode=rewrite_mode)
    row1 = np.array([[1.0, 2.0]], dtype=config.floatX)
    row2 = np.array([[3.0, 4.0]], dtype=config.floatX)
    assert np.allclose(f(row1, row2), np.vstack([row1, row2]))
    assert n_ops(f, Join) == 1
    assert n_ops(f, Repeat) == 0

    # Larger repetition count (5 copies) still rewrites correctly.
    f = function([x], join(0, x, x, x, x, x), mode=rewrite_mode)
    vec2 = np.array([1.0, 2.0], dtype=config.floatX)
    assert np.allclose(f(vec2), np.tile(vec2, 5))
    assert n_ops(f, Join) == 0
    assert n_ops(f, Repeat) == 1


def test_local_join_empty():
# Vector case
empty_vec = np.asarray([], dtype=config.floatX)
Expand Down
Loading