
Commit 1d93510

Refactor infer_shape methods to utilize _gufunc_to_out_shape for output shape computation
1 parent b75c18f commit 1d93510

File tree: 3 files changed, +47 -29 lines changed


pytensor/tensor/nlinalg.py

Lines changed: 7 additions & 16 deletions
@@ -17,6 +17,7 @@
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.utils import _gufunc_to_out_shape


 class MatrixPinv(Op):
@@ -63,7 +64,7 @@ def L_op(self, inputs, outputs, g_outputs):
         return [grad]

     def infer_shape(self, fgraph, node, shapes):
-        return [list(reversed(shapes[0]))]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)


 def pinv(x, hermitian=False):
@@ -156,7 +157,7 @@ def R_op(self, inputs, eval_points):
         return [-matrix_dot(xi, ev, xi)]

     def infer_shape(self, fgraph, node, shapes):
-        return shapes
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)


 inv = matrix_inverse = Blockwise(MatrixInverse())
@@ -225,7 +226,7 @@ def grad(self, inputs, g_outputs):
         return [gz * self(x) * matrix_inverse(x).T]

     def infer_shape(self, fgraph, node, shapes):
-        return [()]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def __str__(self):
         return "Det"
@@ -259,7 +260,7 @@ def perform(self, node, inputs, outputs):
             raise ValueError("Failed to compute determinant", x) from e

     def infer_shape(self, fgraph, node, shapes):
-        return [(), ()]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def __str__(self):
         return "SLogDet"
@@ -317,8 +318,7 @@ def perform(self, node, inputs, outputs):
         w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))

     def infer_shape(self, fgraph, node, shapes):
-        n = shapes[0][0]
-        return [(n,), (n, n)]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)


 eig = Blockwise(Eig())
@@ -619,16 +619,7 @@ def perform(self, node, inputs, outputs):
             s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)

     def infer_shape(self, fgraph, node, shapes):
-        (x_shape,) = shapes
-        M, N = x_shape
-        K = ptm.minimum(M, N)
-        s_shape = (K,)
-        if self.compute_uv:
-            u_shape = (M, M) if self.full_matrices else (M, K)
-            vt_shape = (N, N) if self.full_matrices else (K, N)
-            return [u_shape, s_shape, vt_shape]
-        else:
-            return [s_shape]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def L_op(
         self,
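
For reference, a minimal sketch of how the signature-driven helper reproduces the shapes the old hand-written infer_shape bodies computed. The signature strings below are illustrative stand-ins for the Ops' gufunc_signature attributes (e.g. a pinv-style "(m,n)->(n,m)"), not quotes from the source:

from pytensor.tensor.utils import _gufunc_to_out_shape

# Pinv-style signature: the core dimensions are swapped, so an input of
# shape (3, 5) maps to an output of shape (5, 3), matching the old
# `[list(reversed(shapes[0]))]` (tuples instead of lists).
print(_gufunc_to_out_shape("(m,n)->(n,m)", [(3, 5)]))      # [(5, 3)]

# Det-style signature: a square matrix maps to a scalar, as in the old `[()]`.
print(_gufunc_to_out_shape("(m,m)->()", [(4, 4)]))         # [()]

# Eig-style signature: eigenvalues plus eigenvectors, as in the old `[(n,), (n, n)]`.
print(_gufunc_to_out_shape("(m,m)->(m),(m,m)", [(4, 4)]))  # [(4,), (4, 4)]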

pytensor/tensor/slinalg.py

Lines changed: 7 additions & 13 deletions
@@ -20,6 +20,7 @@
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.utils import _gufunc_to_out_shape
 from pytensor.tensor.variable import TensorVariable

@@ -51,7 +52,7 @@ def __init__(
             self.destroy_map = {0: [0]}

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -269,13 +270,7 @@ def make_node(self, A, b):
         return Apply(self, [A, b], [x])

     def infer_shape(self, fgraph, node, shapes):
-        Ashape, Bshape = shapes
-        rows = Ashape[1]
-        if len(Bshape) == 1:
-            return [(rows,)]
-        else:
-            cols = Bshape[1]
-            return [(rows, cols)]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def L_op(self, inputs, outputs, output_gradients):
         r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
@@ -891,7 +886,7 @@ def perform(self, node, inputs, output_storage):
         X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -963,7 +958,7 @@ def perform(self, node, inputs, output_storage):
         )

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -1083,7 +1078,7 @@ def perform(self, node, inputs, output_storage):
         X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -1181,8 +1176,7 @@ def grad(self, inputs, gout):
         return [gout[0][slc] for slc in slices]

     def infer_shape(self, fgraph, nodes, shapes):
-        first, second = zip(*shapes, strict=True)
-        return [(pt.add(*first), pt.add(*second))]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
         if len(matrices) != self.n_inputs:
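
The same mechanism handles the two-operand Ops in this file: every input shape contributes bindings for its core dimension names, and the output shape is assembled from those bindings. A small sketch with Solve-like signatures (the concrete strings are illustrative; infer_shape passes the Op's own gufunc_signature):

from pytensor.tensor.utils import _gufunc_to_out_shape

# Matrix right-hand side: m is bound from A, k from b, so the output is
# (m, k), matching the old rows/cols computation.
print(_gufunc_to_out_shape("(m,m),(m,k)->(m,k)", [(5, 5), (5, 2)]))  # [(5, 2)]

# Vector right-hand side: the output is a length-m vector, as in the old
# `return [(rows,)]` branch.
print(_gufunc_to_out_shape("(m,m),(m)->(m)", [(5, 5), (5,)]))        # [(5,)]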

pytensor/tensor/utils.py

Lines changed: 33 additions & 0 deletions
@@ -202,6 +202,39 @@ def _parse_gufunc_signature(
     )


+def _gufunc_to_out_shape(
+    signature: str, shapes: list[tuple[int, ...]]
+) -> list[tuple[int, ...]]:
+    """
+    Compute the shape of the output of an Op given its gufunc signature and the
+    shapes of its inputs.
+
+    Parameters
+    ----------
+    signature : str
+        The gufunc signature of the Op.
+        eg: "(m,n),(n,p)->(m,p)".
+
+    shapes : list of tuple of int
+        The list of shapes of the inputs.
+
+    Returns
+    -------
+    out_shape : list of tuple of int
+        The list of shapes of the outputs.
+    """
+    parsed = _parse_gufunc_signature(signature)
+    out_shape = []
+    dic = dict()
+    for i in range(len(parsed[0])):
+        for j in range(len(parsed[0][i])):
+            dic[parsed[0][i][j]] = shapes[i][j]
+    for i in range(len(parsed[1])):
+        temp_list = [dic[x] for x in parsed[1][i]]
+        out_shape.append(tuple(temp_list))
+    return out_shape
+
+
 def safe_signature(
     core_inputs_ndim: Sequence[int],
     core_outputs_ndim: Sequence[int],
core_outputs_ndim: Sequence[int],
