77import pytensor
88from pytensor .graph import FunctionGraph , Variable
99from pytensor .npy_2_compat import normalize_axis_tuple
10+ from pytensor .tensor import Any , Constant
1011from pytensor .utils import hash_from_code
1112
1213
@@ -203,8 +204,8 @@ def _parse_gufunc_signature(
203204
204205
205206def _gufunc_to_out_shape (
206- signature : str , shapes : list [tuple [int , ...]]
207- ) -> list [tuple [int , ...]]:
207+ signature : str , shapes : list [tuple [Any , ...]]
208+ ) -> list [tuple [Any , ...]]:
208209 """
209210 Compute the shape of the output of an Op given its gufunc signature and the
210211 shapes of its inputs.
@@ -215,24 +216,47 @@ def _gufunc_to_out_shape(
215216 The gufunc signature of the Op.
216217 eg: "(m,n),(n,p)->(m,p)".
217218
218- shapes : list of tuple of int
219+ shapes : list of tuple of Any
219220 The list of shapes of the inputs.
220221
221222 Returns
222223 -------
223- out_shape : list of tuple of int
224+ out_shape : list of tuple of Any
224225 The list of shapes of the outputs.
226+
227+ Raises
228+ ------
229+ ValueError
230+ If the signature is invalid for the shapes of the inputs.
225231 """
226- parsed = _parse_gufunc_signature (signature )
227- out_shape = []
228- dic = dict ()
229- for i in range (len (parsed [0 ])):
230- for j in range (len (parsed [0 ][i ])):
231- dic [parsed [0 ][i ][j ]] = shapes [i ][j ]
232- for i in range (len (parsed [1 ])):
233- temp_list = [dic [x ] for x in parsed [1 ][i ]]
234- out_shape .append (tuple (temp_list ))
235- return out_shape
232+ input_sig , output_sig = _parse_gufunc_signature (signature )
233+ dim_to_size : dict [str , Any ] = {}
234+ for input_shape , sig in zip (shapes , input_sig , strict = True ):
235+ for size , dim_name in zip (input_shape , sig , strict = True ):
236+ prev_size = dim_to_size .get (dim_name )
237+ if prev_size is None :
238+ dim_to_size [dim_name ] = size
239+ # Prefer constants
240+ elif not isinstance (prev_size , Constant ):
241+ dim_to_size [dim_name ] = size
242+ elif prev_size .data != size :
243+ raise ValueError (
244+ f"Invalid signature { signature } for shapes { shapes } . "
245+ f"Dimension { dim_name } is not consistent across inputs."
246+ )
247+ out_shapes = []
248+ for output_shape in output_sig :
249+ temp_list = []
250+ for dim in output_shape :
251+ if dim not in dim_to_size :
252+ raise ValueError (
253+ f"Invalid signature { signature } for shapes { shapes } . "
254+ f"Dimension { dim } not in input dimensions."
255+ )
256+ else :
257+ temp_list .append (dim_to_size [dim ])
258+ out_shapes .append ((* temp_list ,))
259+ return out_shapes
236260
237261
238262def safe_signature (
0 commit comments