from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Conv1d

-import numpy as np

def blockwise_conv1d(op, node, **kwargs):
-    # if op.core_op.mode != "valid":
-    #     raise NotImplementedError("Only 'valid' mode is supported for conv1d")
-
-    # def inner_f(x, kernel):
-    #     B, T = x.shape
-    #     Bk, K = kernel.shape
-    #     if B != Bk:
-    #         raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
-
-    #     # 1) Flip each kernel for true convolution
-    #     kernels_flipped = kernel[:, ::-1]  # shape (B, K)
-
-    #     # 2) Reshape input into (N=1, H=T, C_in=B)
-    #     x_in = x.T[None, :, :]
-
-    #     # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1)
-    #     w = kernels_flipped[:, :, None]
-
-    #     # 4) Convolve with one group per channel → valid mode
-    #     y = mx.conv1d(
-    #         x_in, w,
-    #         stride=1,
-    #         padding=0,
-    #         dilation=1,
-    #         groups=B
-    #     )
-    #     # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
-    #     return y[0].T
-
+    """
+    Custom implementation of Blockwise.conv1d for MLX.
+    """
+
    def batched_conv1d(
-        x: mx.array,
-        kernels: mx.array,
-        mode: str = op.core_op.mode,
-        stride: int = 1,
-        dilation: int = 1) -> mx.array:
+        x: mx.array,
+        kernels: mx.array,
+        mode: str = op.core_op.mode,
+        stride: int = 1,
+        dilation: int = 1,
+    ) -> mx.array:
        """
        Apply B separate 1D convolutions (full or valid) to B sequences in parallel.

@@ -53,14 +28,14 @@ def batched_conv1d(
            B kernels of length K.
        mode : {"valid", "full"}
            "valid" → no padding, output length = T - K + 1
-            "full" → zero‑pad so output length = T + K - 1
+            "full" → zero-pad so output length = T + K - 1
        stride : int, convolution stride (default=1)
        dilation : int, convolution dilation (default=1)

        Returns
        -------
        out : array of shape (B, L)
-            where L =
+            where L =
              - T - K + 1 if mode="valid"
              - T + K - 1 if mode="full"
        """
@@ -89,20 +64,15 @@ def batched_conv1d(
        w = kernels_flipped[:, :, None]

        # --- 5) run grouped conv1d ---
-        y = mx.conv1d(
-            x_in, w,
-            stride=stride,
-            padding=pad,
-            dilation=dilation,
-            groups=B
-        )
+        y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
        # y shape: (1, H_out, B)

        # --- 6) return shape (B, H_out) ---
        return y[0].T

    return batched_conv1d

+
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    # 1) If it's a Conv1d Blockwise, use the custom implementation
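
A quick equivalence check for the grouped-conv trick used above (a sketch, not part of this commit; it assumes MLX's conv1d takes (N, H, C_in) inputs and (C_out, H_f, C_in/groups) weights, and uses NumPy's convolve as the reference):

    import mlx.core as mx
    import numpy as np

    B, T, K = 3, 10, 4
    xs = np.random.randn(B, T).astype(np.float32)
    ks = np.random.randn(B, K).astype(np.float32)

    # Same steps as batched_conv1d in "valid" mode:
    x_in = mx.array(xs.T[None, :, :])  # (1, T, B): one "image" with B channels
    ks_flipped = np.ascontiguousarray(ks[:, ::-1])  # flip for true convolution
    w = mx.array(ks_flipped[:, :, None])  # (B, K, 1): one kernel per group
    y = mx.conv1d(x_in, w, stride=1, padding=0, dilation=1, groups=B)
    out = np.array(y[0].T)  # (B, T - K + 1)

    # Reference: one true convolution per (sequence, kernel) pair
    expected = np.stack([np.convolve(xs[i], ks[i], mode="valid") for i in range(B)])
    assert np.allclose(out, expected, atol=1e-5)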