From 12ccbb8390cb5103ac5e1d535bc9d506f48430be Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 9 May 2025 17:44:31 +0800 Subject: [PATCH 1/4] Implement simple wrapper Op for `scipy.signal.convolve2d` --- pytensor/tensor/signal/conv.py | 116 ++++++++++++++++++++++++++++++- tests/tensor/signal/test_conv.py | 30 +++++++- 2 files changed, 144 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index 9eb9a8abf7..e0fef070f2 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -2,6 +2,7 @@ import numpy as np from numpy import convolve as numpy_convolve +from scipy.signal import convolve2d as scipy_convolve2d from pytensor.gradient import DisconnectedType from pytensor.graph import Apply, Constant @@ -11,7 +12,7 @@ from pytensor.tensor.basic import as_tensor_variable, join, zeros from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.math import maximum, minimum, switch -from pytensor.tensor.type import vector +from pytensor.tensor.type import matrix, vector from pytensor.tensor.variable import TensorVariable @@ -211,3 +212,116 @@ def convolve1d( full_mode = as_scalar(np.bool_(mode == "full")) return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode)) + + +class Convolve2D(Op): + __props__ = ("mode", "boundary", "fillvalue") + gufunc_signature = "(n,m),(k,l)->(o,p)" + + def __init__( + self, + mode: Literal["full", "valid", "same"] = "full", + boundary: Literal["fill", "wrap", "symm"] = "fill", + fillvalue: float | int = 0, + ): + if mode not in ("full", "valid", "same"): + raise ValueError(f"Invalid mode: {mode}") + if boundary not in ("fill", "wrap", "symm"): + raise ValueError(f"Invalid boundary: {boundary}") + + self.mode = mode + self.boundary = boundary + self.fillvalue = fillvalue + + def make_node(self, in1, in2): + in1, in2 = map(as_tensor_variable, (in1, in2)) + + assert in1.ndim == 2 + assert in2.ndim == 2 + + dtype = upcast(in1.dtype, in2.dtype) + + n, m = in1.type.shape + k, l = in2.type.shape + + if any(x is None for x in (n, m, k, l)): + out_shape = (None, None) + elif self.mode == "full": + out_shape = (n + k - 1, m + l - 1) + elif self.mode == "valid": + out_shape = (n - k + 1, m - l + 1) + else: # mode == "same" + out_shape = (n, m) + + out = matrix(dtype=dtype, shape=out_shape) + return Apply(self, [in1, in2], [out]) + + def perform(self, node, inputs, outputs): + in1, in2 = inputs + outputs[0][0] = scipy_convolve2d( + in1, in2, mode=self.mode, boundary=self.boundary, fillvalue=self.fillvalue + ) + + def infer_shape(self, fgraph, node, shapes): + in1_shape, in2_shape = shapes + n, m = in1_shape + k, l = in2_shape + + if self.mode == "full": + shape = (n + k - 1, m + l - 1) + elif self.mode == "valid": + shape = ( + maximum(n, k) - minimum(n, k) + 1, + maximum(m, l) - minimum(m, l) + 1, + ) + else: # self.mode == 'same': + shape = (n, m) + + return [shape] + + def L_op(self, inputs, outputs, output_grads): + raise NotImplementedError + + +def convolve2d( + in1: "TensorLike", + in2: "TensorLike", + mode: Literal["full", "valid", "same"] = "full", + boundary: Literal["fill", "wrap", "symm"] = "fill", + fillvalue: float | int = 0, +) -> TensorVariable: + """Convolve two two-dimensional arrays. + + Convolve in1 and in2, with the output size determined by the mode argument. + + Parameters + ---------- + in1 : (..., N, M) tensor_like + First input. + in2 : (..., K, L) tensor_like + Second input. 
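+        The last two dimensions are convolved with those of ``in1``; leading
+        batch dimensions are broadcast against each other.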
+ mode : {'full', 'valid', 'same'}, optional + A string indicating the size of the output: + - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1). + - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1). + - 'same': The output is the same size as in1, centered with respect to the 'full' output. + boundary : {'fill', 'wrap', 'symm'}, optional + A string indicating how to handle boundaries: + - 'fill': Pads the input arrays with fillvalue. + - 'wrap': Circularly wraps the input arrays. + - 'symm': Symmetrically reflects the input arrays. + fillvalue : float or int, optional + The value to use for padding when boundary is 'fill'. Default is 0. + Returns + ------- + out: tensor_variable + The discrete linear convolution of in1 with in2. + + """ + in1 = as_tensor_variable(in1) + in2 = as_tensor_variable(in2) + + blockwise_convolve = Blockwise( + Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue) + ) + return cast(TensorVariable, blockwise_convolve(in1, in2)) diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py index 4df25cc1ca..b780857a64 100644 --- a/tests/tensor/signal/test_conv.py +++ b/tests/tensor/signal/test_conv.py @@ -3,13 +3,14 @@ import numpy as np import pytest from scipy.signal import convolve as scipy_convolve +from scipy.signal import convolve2d as scipy_convolve2d from pytensor import config, function, grad from pytensor.graph.rewriting import rewrite_graph from pytensor.graph.traversal import ancestors, io_toposort from pytensor.tensor import matrix, tensor, vector from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.signal.conv import Convolve1d, convolve1d +from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d from tests import unittest_tools as utt @@ -137,3 +138,30 @@ def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark): @pytest.mark.parametrize("convolve_mode", ["full", "valid"]) def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark): convolve1d_grad_benchmarker(convolve_mode, "FAST_RUN", benchmark) + + +@pytest.mark.parametrize( + "kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}" +) +@pytest.mark.parametrize( + "data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}" +) +@pytest.mark.parametrize("mode", ["full", "valid", "same"]) +@pytest.mark.parametrize("boundary", ["fill", "wrap", "symm"]) +def test_convolve2d(kernel_shape, data_shape, mode, boundary): + data = matrix("data") + kernel = matrix("kernel") + op = partial(convolve2d, mode=mode, boundary=boundary, fillvalue=0) + + rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode)))) + data_val = rng.normal(size=data_shape).astype(data.dtype) + kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype) + + fn = function([data, kernel], op(data, kernel)) + np.testing.assert_allclose( + fn(data_val, kernel_val), + scipy_convolve2d( + data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0 + ), + rtol=1e-6 if config.floatX == "float32" else 1e-15, + ) From 7a592d4323b231c76427eaf97262b33dd379dce3 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 10 May 2025 09:52:16 +0800 Subject: [PATCH 2/4] Better shape inference --- pytensor/tensor/signal/conv.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py 
index e0fef070f2..9e0456aa76 100644
--- a/pytensor/tensor/signal/conv.py
+++ b/pytensor/tensor/signal/conv.py
@@ -244,15 +244,19 @@ def make_node(self, in1, in2):
         n, m = in1.type.shape
         k, l = in2.type.shape
 
-        if any(x is None for x in (n, m, k, l)):
-            out_shape = (None, None)
-        elif self.mode == "full":
-            out_shape = (n + k - 1, m + l - 1)
+        if self.mode == "full":
+            shape_1 = None if (n is None or k is None) else n + k - 1
+            shape_2 = None if (m is None or l is None) else m + l - 1
+
         elif self.mode == "valid":
-            out_shape = (n - k + 1, m - l + 1)
+            shape_1 = None if (n is None or k is None) else max(n, k) - min(n, k) + 1
+            shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1
+
         else:  # mode == "same"
-            out_shape = (n, m)
+            shape_1 = n
+            shape_2 = m
 
+        out_shape = (shape_1, shape_2)
         out = matrix(dtype=dtype, shape=out_shape)
         return Apply(self, [in1, in2], [out])
 

From c2819a8daeb8441c2fac3c69133a4d816277cb58 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sat, 12 Jul 2025 22:52:21 +0800
Subject: [PATCH 3/4] conv2d gradient v0

---
 pytensor/tensor/signal/conv.py   | 32 ++++++++++++++++++++++++-------
 tests/tensor/signal/test_conv.py | 33 +++++++++++++++++++++++++++++++-
 2 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py
index 9e0456aa76..32db551a49 100644
--- a/pytensor/tensor/signal/conv.py
+++ b/pytensor/tensor/signal/conv.py
@@ -6,6 +6,7 @@
 
 from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply, Constant
+from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
 from pytensor.scalar import as_scalar
 from pytensor.scalar.basic import upcast
@@ -220,18 +221,16 @@ class Convolve2D(Op):
 
     def __init__(
         self,
-        mode: Literal["full", "valid", "same"] = "full",
+        mode: Literal["full", "valid"] = "full",
         boundary: Literal["fill", "wrap", "symm"] = "fill",
         fillvalue: float | int = 0,
     ):
-        if mode not in ("full", "valid", "same"):
+        if mode not in ("full", "valid"):
            raise ValueError(f"Invalid mode: {mode}")
-        if boundary not in ("fill", "wrap", "symm"):
-            raise ValueError(f"Invalid boundary: {boundary}")
 
         self.mode = mode
-        self.boundary = boundary
         self.fillvalue = fillvalue
+        self.boundary = boundary
 
     def make_node(self, in1, in2):
         in1, in2 = map(as_tensor_variable, (in1, in2))
@@ -262,8 +261,13 @@ def make_node(self, in1, in2):
 
     def perform(self, node, inputs, outputs):
         in1, in2 = inputs
+
+        # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
+        #     outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
+        #
+        # else:
         outputs[0][0] = scipy_convolve2d(
-            in1, in2, mode=self.mode, boundary=self.boundary, fillvalue=self.fillvalue
+            in1, in2, mode=self.mode, fillvalue=self.fillvalue, boundary=self.boundary
         )
 
     def infer_shape(self, fgraph, node, shapes):
@@ -284,7 +288,18 @@ def infer_shape(self, fgraph, node, shapes):
         return [shape]
 
     def L_op(self, inputs, outputs, output_grads):
-        raise NotImplementedError
+        in1, in2 = inputs
+        incoming_grads = output_grads[0]
+
+        if self.mode == "full":
+            prop_dict = self._props_dict()
+            prop_dict["mode"] = "valid"
+            conv_valid = type(self)(**prop_dict)
+
+            in1_grad = conv_valid(in2, incoming_grads)
+            in2_grad = conv_valid(in1, incoming_grads)
+
+            return [in1_grad, in2_grad]
 
 
 def convolve2d(
@@ -325,6 +340,9 @@ def convolve2d(
     in1 = as_tensor_variable(in1)
     in2 = as_tensor_variable(in2)
 
+    # TODO: Handle boundaries symbolically
+    # TODO: Handle 'same' symbolically
+
     blockwise_convolve = Blockwise(
         Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
     )
diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py
index b780857a64..7e3b703448 100644
--- a/tests/tensor/signal/test_conv.py
+++ b/tests/tensor/signal/test_conv.py
@@ -163,5 +163,36 @@ def test_convolve2d(kernel_shape, data_shape, mode, boundary):
         scipy_convolve2d(
             data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
         ),
-        rtol=1e-6 if config.floatX == "float32" else 1e-15,
+        atol=1e-6 if config.floatX == "float32" else 1e-8,
     )
+
+    utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val], eps=1e-4)
+
+
+# @pytest.mark.parametrize(
+#     "data_shape, kernel_shape", [[(10, 1, 8, 8), (3, 1, 3, 3)], # 8x8 grayscale
+#                                 [(1000, 1, 8, 8), (3, 1, 1, 3)], # same, but with 1000 images
+#                                 [(10, 3, 64, 64), (10, 3, 8, 8)], # 64x64 RGB
+#                                 [(1000, 3, 64, 64), (10, 3, 8, 8)], # same, but with 1000 images
+#                                 [(3, 100, 100, 100), (250, 100, 50, 50)]], # Very large, deep hidden layer or something
+#
+#     ids=lambda x: f"data_shape={x[0]}, kernel_shape={x[1]}"
+# )
+# @pytest.mark.parametrize('func', ['new', 'theano'], ids=['new-impl', 'theano-impl'])
+# def test_conv2d_nn_benchmark(data_shape, kernel_shape, func, benchmark):
+#     import pytensor.tensor as pt
+#     x = pt.tensor("x", shape=data_shape)
+#     y = pt.tensor("y", shape=kernel_shape)
+#
+#     if func == 'new':
+#         out = nn_conv2d(x, y)
+#     else:
+#         out = conv2d(input=x, filters=y, border_mode="valid")
+#
+#     rng = np.random.default_rng(38)
+#     x_test = rng.normal(size=data_shape).astype(x.dtype)
+#     y_test = rng.normal(size=kernel_shape).astype(y.dtype)
+#
+#     fn = function([x, y], out, trust_input=True)
+#
+#     benchmark(fn, x_test, y_test)

From 1e61660ce1c52ccd8d27adfcad4a80b514708175 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sun, 23 Nov 2025 12:43:53 -0600
Subject: [PATCH 4/4] Implement Convolve2d and gradients

---
 pytensor/tensor/basic.py         |   4 +-
 pytensor/tensor/signal/conv.py   | 215 ++++++++++++++-----------------
 tests/tensor/signal/test_conv.py |  77 +++++------
 3 files changed, 141 insertions(+), 155 deletions(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index e789659474..b9bf4df4a5 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -921,7 +921,7 @@ def zeros_like(model, dtype=None, opt=False):
     return fill(_model, ret)
 
 
-def zeros(shape, dtype=None):
+def zeros(shape, dtype=None) -> TensorVariable:
     """Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
     if not (
         isinstance(shape, np.ndarray | Sequence)
@@ -933,7 +933,7 @@ def zeros(shape, dtype=None):
     return alloc(np.array(0, dtype=dtype), *shape)
 
 
-def ones(shape, dtype=None):
+def ones(shape, dtype=None) -> TensorVariable:
     """Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
     if not (
         isinstance(shape, np.ndarray | Sequence)
diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py
index 32db551a49..4621a08b1c 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal +from typing import cast as type_cast import numpy as np from numpy import convolve as numpy_convolve @@ -13,7 +14,9 @@ from pytensor.tensor.basic import as_tensor_variable, join, zeros from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.math import maximum, minimum, switch -from pytensor.tensor.type import matrix, vector +from pytensor.tensor.pad import pad +from pytensor.tensor.subtensor import flip +from pytensor.tensor.type import tensor from pytensor.tensor.variable import TensorVariable @@ -21,54 +24,69 @@ from pytensor.tensor import TensorLike -class Convolve1d(COp): +class AbstractConvolveNd: __props__ = () - gufunc_signature = "(n),(k),()->(o)" + ndim: int + + @property + def gufunc_signature(self): + data_signature = ",".join([f"n{i}" for i in range(self.ndim)]) + kernel_signature = ",".join([f"k{i}" for i in range(self.ndim)]) + output_signature = ",".join([f"o{i}" for i in range(self.ndim)]) + + return f"({data_signature}),({kernel_signature}),()->({output_signature})" def make_node(self, in1, in2, full_mode): in1 = as_tensor_variable(in1) in2 = as_tensor_variable(in2) full_mode = as_scalar(full_mode) - if not (in1.ndim == 1 and in2.ndim == 1): - raise ValueError("Convolution inputs must be vector (ndim=1)") + ndim = self.ndim + if not (in1.ndim == ndim and in2.ndim == self.ndim): + raise ValueError( + f"Convolution inputs must have ndim={ndim}, got: in1={in1.ndim}, in2={in2.ndim}" + ) if not full_mode.dtype == "bool": - raise ValueError("Convolution mode must be a boolean type") + raise ValueError("Convolution full_mode flag must be a boolean type") - dtype = upcast(in1.dtype, in2.dtype) - n = in1.type.shape[0] - k = in2.type.shape[0] match full_mode: case Constant(): static_mode = "full" if full_mode.data else "valid" case _: static_mode = None - if n is None or k is None or static_mode is None: - out_shape = (None,) - elif static_mode == "full": - out_shape = (n + k - 1,) - else: # mode == "valid": - out_shape = (max(n, k) - min(n, k) + 1,) + if static_mode is None: + out_shape = (None,) * ndim + else: + out_shape = [] + # TODO: Raise if static shapes are not valid (one input size doesn't dominate the other) + for n, k in zip(in1.type.shape, in2.type.shape): + if n is None or k is None: + out_shape.append(None) + elif static_mode == "full": + out_shape.append( + n + k - 1, + ) + else: # mode == "valid": + out_shape.append( + max(n, k) - min(n, k) + 1, + ) + out_shape = tuple(out_shape) - out = vector(dtype=dtype, shape=out_shape) - return Apply(self, [in1, in2, full_mode], [out]) + dtype = upcast(in1.dtype, in2.dtype) - def perform(self, node, inputs, outputs): - # We use numpy_convolve as that's what scipy would use if method="direct" was passed. - # And mode != "same", which this Op doesn't cover anyway. 
- in1, in2, full_mode = inputs - outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid") + out = tensor(dtype=dtype, shape=out_shape) + return Apply(self, [in1, in2, full_mode], [out]) def infer_shape(self, fgraph, node, shapes): _, _, full_mode = node.inputs in1_shape, in2_shape, _ = shapes - n = in1_shape[0] - k = in2_shape[0] - shape_valid = maximum(n, k) - minimum(n, k) + 1 - shape_full = n + k - 1 - shape = switch(full_mode, shape_full, shape_valid) - return [[shape]] + out_shape = [ + switch(full_mode, n + k - 1, maximum(n, k) - minimum(n, k) + 1) + for n, k in zip(in1_shape, in2_shape) + ] + + return [out_shape] def connection_pattern(self, node): return [[True], [True], [False]] @@ -77,22 +95,34 @@ def L_op(self, inputs, outputs, output_grads): in1, in2, full_mode = inputs [grad] = output_grads - n = in1.shape[0] - k = in2.shape[0] + n = in1.shape + k = in2.shape + # Note: this assumes the shape of one input dominates the other over all dimensions (which is required for a valid forward) # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve # The expression below is equivalent to ~(full_mode | (k >= n)) - full_mode_in1_bar = ~full_mode & (k < n) + full_mode_in1_bar = ~full_mode & (k < n).any() # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve # The expression below is equivalent to ~(full_mode | (n >= k)) - full_mode_in2_bar = ~full_mode & (n < k) + full_mode_in2_bar = ~full_mode & (n < k).any() return [ - self(grad, in2[::-1], full_mode_in1_bar), - self(grad, in1[::-1], full_mode_in2_bar), + self(grad, flip(in2), full_mode_in1_bar), + self(grad, flip(in1), full_mode_in2_bar), DisconnectedType()(), ] + +class Convolve1d(AbstractConvolveNd, COp): + __props__ = () + ndim = 1 + + def perform(self, node, inputs, outputs): + # We use numpy_convolve as that's what scipy would use if method="direct" was passed. + # And mode != "same", which this Op doesn't cover anyway. 
+        in1, in2, full_mode = inputs
+        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
+
     def c_code_cache_version(self):
         return (2,)
 
@@ -212,94 +242,29 @@ def convolve1d(
         mode = "valid"
 
     full_mode = as_scalar(np.bool_(mode == "full"))
-    return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
-
-
-class Convolve2D(Op):
-    __props__ = ("mode", "boundary", "fillvalue")
-    gufunc_signature = "(n,m),(k,l)->(o,p)"
+    return type_cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
 
-    def __init__(
-        self,
-        mode: Literal["full", "valid"] = "full",
-        boundary: Literal["fill", "wrap", "symm"] = "fill",
-        fillvalue: float | int = 0,
-    ):
-        if mode not in ("full", "valid"):
-            raise ValueError(f"Invalid mode: {mode}")
 
-        self.mode = mode
-        self.fillvalue = fillvalue
-        self.boundary = boundary
-
-    def make_node(self, in1, in2):
-        in1, in2 = map(as_tensor_variable, (in1, in2))
-
-        assert in1.ndim == 2
-        assert in2.ndim == 2
-
-        dtype = upcast(in1.dtype, in2.dtype)
-
-        n, m = in1.type.shape
-        k, l = in2.type.shape
-
-        if self.mode == "full":
-            shape_1 = None if (n is None or k is None) else n + k - 1
-            shape_2 = None if (m is None or l is None) else m + l - 1
-
-        elif self.mode == "valid":
-            shape_1 = None if (n is None or k is None) else max(n, k) - min(n, k) + 1
-            shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1
-
-        else:  # mode == "same"
-            shape_1 = n
-            shape_2 = m
-
-        out_shape = (shape_1, shape_2)
-        out = matrix(dtype=dtype, shape=out_shape)
-        return Apply(self, [in1, in2], [out])
+class Convolve2d(AbstractConvolveNd, Op):
+    __props__ = ()
+    ndim = 2
 
     def perform(self, node, inputs, outputs):
-        in1, in2 = inputs
+        in1, in2, full_mode = inputs
 
         # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
         #     outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
        #
         # else:
+        # full_mode arrives as a 0-d boolean (np.bool_); .item() extracts a plain Python bool
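+        # Boundary handling ("fill", "wrap", "symm") never reaches this Op: the
+        # convolve2d helper pads in1 symbolically and requests a "valid"
+        # convolution instead, so perform only dispatches on full vs. valid.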
         outputs[0][0] = scipy_convolve2d(
-            in1, in2, mode=self.mode, fillvalue=self.fillvalue, boundary=self.boundary
+            in1,
+            in2,
+            mode="full" if full_mode.item() else "valid",
         )
 
-    def infer_shape(self, fgraph, node, shapes):
-        in1_shape, in2_shape = shapes
-        n, m = in1_shape
-        k, l = in2_shape
-
-        if self.mode == "full":
-            shape = (n + k - 1, m + l - 1)
-        elif self.mode == "valid":
-            shape = (
-                maximum(n, k) - minimum(n, k) + 1,
-                maximum(m, l) - minimum(m, l) + 1,
-            )
-        else:  # self.mode == 'same':
-            shape = (n, m)
-
-        return [shape]
-
-    def L_op(self, inputs, outputs, output_grads):
-        in1, in2 = inputs
-        incoming_grads = output_grads[0]
-
-        if self.mode == "full":
-            prop_dict = self._props_dict()
-            prop_dict["mode"] = "valid"
-            conv_valid = type(self)(**prop_dict)
-
-            in1_grad = conv_valid(in2, incoming_grads)
-            in2_grad = conv_valid(in1, incoming_grads)
-
-            return [in1_grad, in2_grad]
+
+blockwise_convolve_2d = Blockwise(Convolve2d())
 
 
 def convolve2d(
@@ -340,10 +305,28 @@ def convolve2d(
     in1 = as_tensor_variable(in1)
     in2 = as_tensor_variable(in2)
 
-    # TODO: Handle boundaries symbolically
-    # TODO: Handle 'same' symbolically
+    if mode == "same":
+        raise NotImplementedError("same mode not implemented for convolve2d")
+
+    if mode != "valid" and (boundary != "fill" or fillvalue != 0):
+        # We use a valid convolution on an appropriately padded input
+        *_, k, l = in2.shape
+        ndim = in1.type.ndim  # pad_width must match the array being padded
+
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, :].set(k - 1)
+        pad_width = pad_width[-1, :].set(l - 1)
+        if boundary == "fill":
+            in1 = pad(
+                in1, pad_width=pad_width, mode="constant", constant_values=fillvalue
+            )
+        elif boundary == "wrap":
+            in1 = pad(in1, pad_width=pad_width, mode="wrap")
+        elif boundary == "symm":
+            in1 = pad(in1, pad_width=pad_width, mode="symmetric")
 
-    blockwise_convolve = Blockwise(
-        Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
-    )
-    return cast(TensorVariable, blockwise_convolve(in1, in2))
+        mode = "valid"
+
+    full_mode = as_scalar(np.bool_(mode == "full"))
+    return type_cast(TensorVariable, blockwise_convolve_2d(in1, in2, full_mode))
diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py
index 7e3b703448..90751d4673 100644
--- a/tests/tensor/signal/test_conv.py
+++ b/tests/tensor/signal/test_conv.py
@@ -9,6 +9,7 @@
 from pytensor.graph.rewriting import rewrite_graph
 from pytensor.graph.traversal import ancestors, io_toposort
 from pytensor.tensor import matrix, tensor, vector
+from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d
 from tests import unittest_tools as utt
@@ -47,7 +48,7 @@ def test_convolve1d_batch():
     res = out.eval({x: x_test, y: y_test})
     # Second entry of x, y are just y, x respectively,
     # so res[0] and res[1] should be identical.
-    rtol = 1e-6 if config.floatX == "float32" else 1e-15
+    rtol = 1e-6 if config.floatX == "float32" else 1e-12
     res_np = np.convolve(x_test[0], y_test[0])
     np.testing.assert_allclose(res[0], res_np, rtol=rtol)
     np.testing.assert_allclose(res[1], res_np, rtol=rtol)
@@ -146,53 +147,55 @@ def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
 @pytest.mark.parametrize(
     "data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
 )
-@pytest.mark.parametrize("mode", ["full", "valid", "same"])
-@pytest.mark.parametrize("boundary", ["fill", "wrap", "symm"])
-def test_convolve2d(kernel_shape, data_shape, mode, boundary):
+@pytest.mark.parametrize("mode", ["full", "valid"])
+@pytest.mark.parametrize(
+    "boundary, boundary_kwargs",
+    [
+        ("fill", {"fillvalue": 0}),
+        ("fill", {"fillvalue": 0.5}),
+        ("wrap", {}),
+        ("symm", {}),
+    ],
+)
+def test_convolve2d(kernel_shape, data_shape, mode, boundary, boundary_kwargs):
     data = matrix("data")
     kernel = matrix("kernel")
-    op = partial(convolve2d, mode=mode, boundary=boundary, fillvalue=0)
+    op = partial(convolve2d, mode=mode, boundary=boundary, **boundary_kwargs)
+    conv_result = op(data, kernel)
+
+    fn = function([data, kernel], conv_result)
 
     rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
     data_val = rng.normal(size=data_shape).astype(data.dtype)
     kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)
 
-    fn = function([data, kernel], op(data, kernel))
     np.testing.assert_allclose(
         fn(data_val, kernel_val),
         scipy_convolve2d(
-            data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
+            data_val, kernel_val, mode=mode, boundary=boundary, **boundary_kwargs
         ),
         atol=1e-6 if config.floatX == "float32" else 1e-8,
     )
 
-    utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val], eps=1e-4)
+    utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val])
+
+
+def test_batched_1d_agrees_with_diagonal_2d():
+    data = matrix("data")
+    kernel_1d = vector("kernel_1d")
+    kernel_2d = expand_dims(kernel_1d, 0)
+
+    output_1d = convolve1d(data, kernel_1d, mode="valid")
+    output_2d = convolve2d(data, kernel_2d, mode="valid")
+
+    grad_1d = grad(output_1d.sum(), kernel_1d).ravel()
+    grad_2d = grad(output_2d.sum(), kernel_1d).ravel()
+
+    fn = function([data, kernel_1d], [output_1d, output_2d, grad_1d, grad_2d])
+
+    data_val = np.random.normal(size=(10, 8)).astype(config.floatX)
+    kernel_1d_val = np.random.normal(size=(3,)).astype(config.floatX)
+
+    forward_1d, forward_2d, backward_1d, backward_2d = fn(data_val, kernel_1d_val)
+    np.testing.assert_allclose(forward_1d, forward_2d)
+    np.testing.assert_allclose(backward_1d, backward_2d)
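
A standalone check of the two identities the final commit relies on: the
boundary trick in convolve2d (pad the input, then run a "valid" convolution)
and the flip-then-valid gradient in AbstractConvolveNd.L_op. This is a minimal
sketch against plain NumPy/SciPy, independent of PyTensor; the array names
(a, k, g) are illustrative.

    import numpy as np
    from scipy.signal import convolve2d

    rng = np.random.default_rng(0)
    a = rng.normal(size=(5, 5))  # data
    k = rng.normal(size=(3, 3))  # kernel
    pw = ((k.shape[0] - 1,) * 2, (k.shape[1] - 1,) * 2)

    # (1) A "full" convolution under a non-trivial boundary equals a "valid"
    # convolution of the correspondingly padded input.
    for scipy_boundary, np_pad_mode in [("wrap", "wrap"), ("symm", "symmetric")]:
        ref = convolve2d(a, k, mode="full", boundary=scipy_boundary)
        padded = np.pad(a, pw, mode=np_pad_mode)
        np.testing.assert_allclose(convolve2d(padded, k, mode="valid"), ref)

    # (2) The gradient of sum(g * conv_full(a, k)) w.r.t. a is
    # conv_valid(g, flip(k)). Convolution is linear in a, so central
    # differences are exact up to floating-point rounding.
    g = rng.normal(size=(7, 7))  # cotangent, same shape as the "full" output
    analytic = convolve2d(g, np.flip(k), mode="valid")
    eps = 1e-6
    numeric = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        da = np.zeros_like(a)
        da[idx] = eps
        f_plus = np.sum(g * convolve2d(a + da, k, mode="full"))
        f_minus = np.sum(g * convolve2d(a - da, k, mode="full"))
        numeric[idx] = (f_plus - f_minus) / (2 * eps)
    np.testing.assert_allclose(analytic, numeric, atol=1e-7)

The same padding argument is what lets the series drop boundary and fillvalue
from the Op's __props__: after patch 4 the Op itself only distinguishes full
from valid, and everything else is graph-level padding.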