Skip to content

Commit 925f91b

Browse files
cetagostini, ricardoV94, jessegrabowski
committed
SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY
A shout out for the fathers of the day! Co-Authored-By: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Co-Authored-By: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
1 parent cf88a8a commit 925f91b

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,24 @@ def batched_conv1d(
105105

106106
@mlx_funcify.register(Blockwise)
107107
def funcify_Blockwise(op: Blockwise, node, **kwargs):
108+
# 1) If it's a Conv1d Blockwise, use the custom implementation
108109
if isinstance(op.core_op, Conv1d):
109110
return blockwise_conv1d(op, node, **kwargs)
110-
111-
core_f = mlx_funcify(op.core_op, node)
112111

113-
def blockwise_f(*inputs):
114-
return blockwise_f(*inputs)
112+
# 2) Otherwise, get the core python function for this Blockwise
115113
core_node = op._create_dummy_core_node(node.inputs)
116-
117114
core_f = mlx_funcify(op.core_op, core_node)
118-
blockwise_f = core_f
119-
for i in range(op.batch_ndim(node)):
120-
blockwise_f = mx.vmap(blockwise_f)
121115

116+
# 3) Determine how many inputs correspond to batch dimensions
117+
n_batch = op.batch_ndim(node)
118+
119+
# 4) Build in_axes: map only the first n_batch args, keep the rest static
120+
in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs)))
121+
122+
# 5) Vectorize (vmap) with in_axes
123+
blockwise_f = mx.vmap(core_f, in_axes=in_axes)
124+
125+
# 6) Return the mapped function
122126
def blockwise_fun(*inputs):
123127
return blockwise_f(*inputs)
124128

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ def careduce(x):
4343
axis = list(range(x.ndim))
4444

4545
if acc_dtype is None:
46-
acc_dtype = x.dtype.type
46+
acc_dtype = x.dtype
4747

4848
if op_nfunc_spec:
4949
mlx_op = getattr(mx, op_nfunc_spec[0])
5050
return mlx_op(x, axis=axis)
51-
return mlx_op(x, axis=axis).astype(acc_dtype)
51+
# return mlx_op(x, axis=axis).astype(acc_dtype)
5252

5353
# The PyTensor `Op` didn't tell us which NumPy equivalent to use (or
5454
# there isn't one), so we use this fallback approach

0 commit comments

Comments (0)