|
7 | 7 | import numpy as np |
8 | 8 |
|
9 | 9 | def blockwise_conv1d(op, node, **kwargs): |
10 | | - if op.core_op.mode != "valid": |
11 | | - raise NotImplementedError("Only 'valid' mode is supported for conv1d") |
12 | | - # batches_ndim = op.batch_ndim(node) |
13 | | - # if batches_ndim != 1: |
14 | | - # raise NotImplementedError("Only 1D batches are supported for conv1d") |
15 | 12 |
|
16 | | - # _, kernel = node.inputs |
17 | | - # if not all(kernel.type.broadcastable[:batches_ndim]): |
18 | | - # raise NotImplementedError("Only 1D batches are supported for conv1d") |
19 | 38 |
|
20 | | - def inner_f(x, kernel): |
| 39 | + def batched_conv1d( |
| 40 | + x: mx.array, |
| 41 | + kernels: mx.array, |
| 42 | + mode: str = op.core_op.mode, |
| 43 | + stride: int = 1, |
| 44 | + dilation: int = 1) -> mx.array: |
| 45 | + """ |
| 46 | + Apply B separate 1D convolutions (full or valid) to B sequences in parallel. |
| 47 | +
|
| 48 | + Parameters |
| 49 | + ---------- |
| 50 | + x : array of shape (B, T) |
| 51 | + B sequences of length T. |
| 52 | + kernels : array of shape (B, K) |
| 53 | + B kernels of length K. |
| 54 | + mode : {"valid", "full"} |
| 55 | + "valid" → no padding, output length = T - K + 1 |
| 56 | + "full" → zero‑pad so output length = T + K - 1 |
| 57 | + stride : int, convolution stride (default=1) |
| 58 | + dilation : int, convolution dilation (default=1) |
| 59 | +
|
| 60 | + Returns |
| 61 | + ------- |
| 62 | + out : array of shape (B, L) |
| 63 | + where L = |
| 64 | + - T - K + 1 if mode="valid" |
| 65 | + - T + K - 1 if mode="full" |
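| | + |
| | + Examples |
| | + -------- |
| | + A minimal usage sketch; shapes here are illustrative only: |
| | + |
| | + >>> x = mx.ones((4, 10)) # B=4 sequences of length T=10 |
| | + >>> k = mx.ones((4, 3)) # B=4 kernels of length K=3 |
| | + >>> batched_conv1d(x, k, mode="valid").shape # (4, 8): T - K + 1 |
| | + >>> batched_conv1d(x, k, mode="full").shape # (4, 12): T + K - 1 |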
| 66 | + """ |
| 67 | + # --- 1) shape checks --- |
21 | 68 | B, T = x.shape |
22 | | - Bk, K = kernel.shape |
| 69 | + Bk, K = kernels.shape |
23 | 70 | if B != Bk: |
24 | 71 | raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}") |
25 | 72 |
|
26 | | - # 1) Flip each kernel for true convolution |
27 | | - kernels_flipped = kernel[:, ::-1] # shape (B, K) |
| 73 | + # --- 2) flip kernels for convolution --- |
| 74 | + kernels_flipped = kernels[:, ::-1] # shape (B, K) |
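| | + # (mx.conv1d, like most deep-learning conv ops, computes |
| | + # cross-correlation, so flipping each kernel yields true convolution) |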
| 75 | + |
| 76 | + # --- 3) decide padding --- |
| 77 | + if mode == "valid": |
| 78 | + pad = 0 |
| 79 | + elif mode == "full": |
| 80 | + pad = (K - 1) * dilation |
| 81 | + else: |
| 82 | + raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'") |
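| | + # (with stride 1, H_out = T + 2*pad - dilation*(K-1), so |
| | + # pad = (K-1)*dilation gives H_out = T + (K-1)*dilation, |
| | + # i.e. T + K - 1 for dilation=1, the "full" output length) |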
28 | 83 |
|
29 | | - # 2) Reshape input into (N=1, H=T, C_in=B) |
30 | | - x_in = x.T[None, :, :] |
| 84 | + # --- 4) reshape into MLX conv1d form --- |
| 85 | + # input: (N=1, H=T, C_in=B) |
| 86 | + x_in = x.T[None, :, :] |
31 | 87 |
|
32 | | - # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1) |
33 | | - w = kernels_flipped[:, :, None] |
| 88 | + # weight: (C_out=B, H_f=K, C_in=1) |
| 89 | + w = kernels_flipped[:, :, None] |
34 | 90 |
|
35 | | - # 4) Convolve with one group per channel → valid mode |
| 91 | + # --- 5) run grouped conv1d --- |
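| | + # groups=B pairs output channel b with input channel b only, so one |
| | + # call computes B independent single-channel convolutions |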
36 | 92 | y = mx.conv1d( |
37 | 93 | x_in, w, |
38 | | - stride=1, |
39 | | - padding=0, |
40 | | - dilation=1, |
| 94 | + stride=stride, |
| 95 | + padding=pad, |
| 96 | + dilation=dilation, |
41 | 97 | groups=B |
42 | 98 | ) |
43 | | - # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1) |
| 99 | + # y shape: (1, H_out, B) |
| 100 | + |
| 101 | + # --- 6) return shape (B, H_out) --- |
44 | 102 | return y[0].T |
45 | | - return inner_f |
| 103 | + |
| 104 | + return batched_conv1d |
46 | 105 |
|
47 | 106 | @mlx_funcify.register(Blockwise) |
48 | 107 | def funcify_Blockwise(op: Blockwise, node, **kwargs): |
49 | 108 | if isinstance(op.core_op, Conv1d): |
50 | 109 | return blockwise_conv1d(op, node, **kwargs) |
51 | 110 |
|
52 | | - core_f = mlx_funcify(op.core_op) |
| 111 | + core_f = mlx_funcify(op.core_op, node) |
53 | 112 |
|
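| | + # blockwise_f is a thin fallback: it assumes core_f already broadcasts |
| | + # over the Blockwise batch dimensions. A hypothetical stricter variant |
| | + # could vectorize explicitly, e.g. mx.vmap(core_f, in_axes=0, out_axes=0). |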
54 | 113 | def blockwise_f(*inputs): |
55 | | - return blockwise_f(*inputs) |
| 114 | + # calling blockwise_f here recursed forever; call the core function |
| 115 | + return core_f(*inputs) |
| 116 | + |
| 117 | + return blockwise_f |
|