Skip to content

Commit e5e0af6

Browse files
cetagostinijessegrabowski
authored andcommitted
pre-commit
1 parent 925f91b commit e5e0af6

File tree

1 file changed

+14
-44
lines changed

1 file changed

+14
-44
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,19 @@
44
from pytensor.tensor.blockwise import Blockwise
55
from pytensor.tensor.signal.conv import Conv1d
66

7-
import numpy as np
87

98
def blockwise_conv1d(op, node, **kwargs):
10-
# if op.core_op.mode != "valid":
11-
# raise NotImplementedError("Only 'valid' mode is supported for conv1d")
12-
13-
# def inner_f(x, kernel):
14-
# B, T = x.shape
15-
# Bk, K = kernel.shape
16-
# if B != Bk:
17-
# raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
18-
19-
# # 1) Flip each kernel for true convolution
20-
# kernels_flipped = kernel[:, ::-1] # shape (B, K)
21-
22-
# # 2) Reshape input into (N=1, H=T, C_in=B)
23-
# x_in = x.T[None, :, :]
24-
25-
# # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1)
26-
# w = kernels_flipped[:, :, None]
27-
28-
# # 4) Convolve with one group per channel → valid mode
29-
# y = mx.conv1d(
30-
# x_in, w,
31-
# stride=1,
32-
# padding=0,
33-
# dilation=1,
34-
# groups=B
35-
# )
36-
# # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
37-
# return y[0].T
38-
9+
"""
10+
Custom implementation of Blockwise.conv1d for MLX.
11+
"""
12+
3913
def batched_conv1d(
40-
x: mx.array,
41-
kernels: mx.array,
42-
mode: str = op.core_op.mode,
43-
stride: int = 1,
44-
dilation: int = 1) -> mx.array:
14+
x: mx.array,
15+
kernels: mx.array,
16+
mode: str = op.core_op.mode,
17+
stride: int = 1,
18+
dilation: int = 1,
19+
) -> mx.array:
4520
"""
4621
Apply B separate 1D convolutions (full or valid) to B sequences in parallel.
4722
@@ -53,14 +28,14 @@ def batched_conv1d(
5328
B kernels of length K.
5429
mode : {"valid", "full"}
5530
"valid" → no padding, output length = T - K + 1
56-
"full" → zeropad so output length = T + K - 1
31+
"full" → zero-pad so output length = T + K - 1
5732
stride : int, convolution stride (default=1)
5833
dilation : int, convolution dilation (default=1)
5934
6035
Returns
6136
-------
6237
out : array of shape (B, L)
63-
where L =
38+
where L =
6439
- T - K + 1 if mode="valid"
6540
- T + K - 1 if mode="full"
6641
"""
@@ -89,20 +64,15 @@ def batched_conv1d(
8964
w = kernels_flipped[:, :, None]
9065

9166
# --- 5) run grouped conv1d ---
92-
y = mx.conv1d(
93-
x_in, w,
94-
stride=stride,
95-
padding=pad,
96-
dilation=dilation,
97-
groups=B
98-
)
67+
y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
9968
# y shape: (1, H_out, B)
10069

10170
# --- 6) return shape (B, H_out) ---
10271
return y[0].T
10372

10473
return batched_conv1d
10574

75+
10676
@mlx_funcify.register(Blockwise)
10777
def funcify_Blockwise(op: Blockwise, node, **kwargs):
10878
# 1) If it's a Conv1d Blockwise, use the custom implementation

0 commit comments

Comments
 (0)