Skip to content

Commit eebbd47

Browse files
cetagostini and jessegrabowski
authored and committed
I'm going for pizzas, it was an incredible day!
1 parent 8409a27 commit eebbd47

File tree

1 file changed

+82
-23
lines changed

1 file changed

+82
-23
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 82 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,108 @@
77
import numpy as np
88

99
def blockwise_conv1d(op, node, **kwargs):
10-
if op.core_op.mode != "valid":
11-
raise NotImplementedError("Only 'valid' mode is supported for conv1d")
12-
# batches_ndim = op.batch_ndim(node)
13-
# if batches_ndim != 1:
14-
# raise NotImplementedError("Only 1D batches are supported for conv1d")
10+
# if op.core_op.mode != "valid":
11+
# raise NotImplementedError("Only 'valid' mode is supported for conv1d")
1512

16-
# _, kernel = node.inputs
17-
# if not all(kernel.type.broadcastable[:batches_ndim]):
18-
# raise NotImplementedError("Only 1D batches are supported for conv1d")
13+
# def inner_f(x, kernel):
14+
# B, T = x.shape
15+
# Bk, K = kernel.shape
16+
# if B != Bk:
17+
# raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
18+
19+
# # 1) Flip each kernel for true convolution
20+
# kernels_flipped = kernel[:, ::-1] # shape (B, K)
21+
22+
# # 2) Reshape input into (N=1, H=T, C_in=B)
23+
# x_in = x.T[None, :, :]
24+
25+
# # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1)
26+
# w = kernels_flipped[:, :, None]
27+
28+
# # 4) Convolve with one group per channel → valid mode
29+
# y = mx.conv1d(
30+
# x_in, w,
31+
# stride=1,
32+
# padding=0,
33+
# dilation=1,
34+
# groups=B
35+
# )
36+
# # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
37+
# return y[0].T
1938

20-
def inner_f(x, kernel):
39+
def batched_conv1d(
40+
x: mx.array,
41+
kernels: mx.array,
42+
mode: str = op.core_op.mode,
43+
stride: int = 1,
44+
dilation: int = 1) -> mx.array:
45+
"""
46+
Apply B separate 1D convolutions (full or valid) to B sequences in parallel.
47+
48+
Parameters
49+
----------
50+
x : array of shape (B, T)
51+
B sequences of length T.
52+
kernels : array of shape (B, K)
53+
B kernels of length K.
54+
mode : {"valid", "full"}
55+
"valid" → no padding, output length = T - K + 1
56+
"full" → zero‑pad so output length = T + K - 1
57+
stride : int, convolution stride (default=1)
58+
dilation : int, convolution dilation (default=1)
59+
60+
Returns
61+
-------
62+
out : array of shape (B, L)
63+
where L =
64+
- T - K + 1 if mode="valid"
65+
- T + K - 1 if mode="full"
66+
"""
67+
# --- 1) shape checks ---
2168
B, T = x.shape
22-
Bk, K = kernel.shape
69+
Bk, K = kernels.shape
2370
if B != Bk:
2471
raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
2572

26-
# 1) Flip each kernel for true convolution
27-
kernels_flipped = kernel[:, ::-1] # shape (B, K)
73+
# --- 2) flip kernels for convolution ---
74+
kernels_flipped = kernels[:, ::-1] # shape (B, K)
75+
76+
# --- 3) decide padding ---
77+
if mode == "valid":
78+
pad = 0
79+
elif mode == "full":
80+
pad = (K - 1) * dilation
81+
else:
82+
raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'")
2883

29-
# 2) Reshape input into (N=1, H=T, C_in=B)
30-
x_in = x.T[None, :, :]
84+
# --- 4) reshape into MLX conv1d form ---
85+
# input: (N=1, H=T, C_in=B)
86+
x_in = x.T[None, :, :]
3187

32-
# 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1)
33-
w = kernels_flipped[:, :, None]
88+
# weight: (C_out=B, H_f=K, C_in=1)
89+
w = kernels_flipped[:, :, None]
3490

35-
# 4) Convolve with one group per channel → valid mode
91+
# --- 5) run grouped conv1d ---
3692
y = mx.conv1d(
3793
x_in, w,
38-
stride=1,
39-
padding=0,
40-
dilation=1,
94+
stride=stride,
95+
padding=pad,
96+
dilation=dilation,
4197
groups=B
4298
)
43-
# y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
99+
# y shape: (1, H_out, B)
100+
101+
# --- 6) return shape (B, H_out) ---
44102
return y[0].T
45-
return inner_f
103+
104+
return batched_conv1d
46105

47106
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    """Dispatch a Blockwise op to an MLX callable.

    Conv1d cores take the specialized grouped-convolution path; every other
    core op is funcified and applied to the inputs.
    """
    if isinstance(op.core_op, Conv1d):
        return blockwise_conv1d(op, node, **kwargs)

    core_f = mlx_funcify(op.core_op, node)

    def blockwise_f(*inputs):
        # BUG FIX: the original body called `blockwise_f(*inputs)` — i.e.
        # itself — recursing without bound on first invocation.  Delegate to
        # the funcified core op instead.
        # NOTE(review): this applies the core op to the inputs directly; core
        # ops that do not broadcast over the batch dimensions may need an
        # explicit vmap — confirm against the other backends' Blockwise
        # dispatchers.
        return core_f(*inputs)

0 commit comments

Comments
 (0)