Skip to content

Commit f6a11df

Browse files
Implement Solve dispatch in mlx backend
1 parent aa6504f commit f6a11df

File tree

5 files changed

+85
-1
lines changed

5 files changed

+85
-1
lines changed

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
import pytensor.link.mlx.dispatch.signal
1111
import pytensor.link.mlx.dispatch.signal.conv
1212
import pytensor.link.mlx.dispatch.blockwise
13+
import pytensor.link.mlx.dispatch.slinalg
1314
# isort: on
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import warnings
2+
3+
import mlx.core as mx
4+
5+
from pytensor.link.mlx.dispatch.basic import mlx_funcify
6+
from pytensor.tensor.slinalg import Solve
7+
8+
9+
@mlx_funcify.register(Solve)
def mlx_funcify_Solve(op, node, **kwargs):
    """Return an MLX implementation of the PyTensor ``Solve`` Op.

    Parameters
    ----------
    op : Solve
        The Solve Op instance; only its ``assume_a`` attribute is inspected.
    node : Apply
        The Apply node being compiled; used to recover the symbolic output
        dtype so the MLX result can be cast back to it.

    Returns
    -------
    callable
        ``solve(a, b)`` computing ``a^{-1} b`` via ``mx.linalg.solve``.
    """
    assume_a = op.assume_a
    # Resolve the symbolic output dtype name (e.g. "float32") to the
    # corresponding mx dtype object for the final cast.
    out_dtype = getattr(mx, node.outputs[0].dtype)

    if assume_a != "gen":
        # MLX has no specialized solvers, so structured assumptions are
        # silently ignored; warn so users know the hint is dropped here
        # (it can still speed up other backends).
        warnings.warn(
            f"MLX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.\n"
            "If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her' or 'tridiagonal' to improve performance.",
            UserWarning,
        )

    def solve(a, b):
        # MLX only supports solve on CPU
        return mx.linalg.solve(a, b, stream=mx.cpu).astype(out_dtype, stream=mx.cpu)

    return solve

tests/link/mlx/__init__.py

Whitespace-only changes.

tests/link/mlx/test_math.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import mlx as mx
12
import numpy as np
23
import pytest
34

45
import pytensor
56
import pytensor.tensor as pt
67
from pytensor.tensor.math import Argmax, Max
7-
from tests.link.mlx.test_basic import compare_mlx_and_py, mx
8+
from tests.link.mlx.test_basic import compare_mlx_and_py
89

910

1011
def test_dot():

tests/link/mlx/test_slinalg.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import contextlib
2+
3+
import numpy as np
4+
import pytest
5+
6+
import pytensor.tensor as pt
7+
from pytensor import config
8+
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
9+
10+
11+
# MLX complains about a useless vmap when there are no batch dims, so these
# tests run with the "blockwise" rewrites enabled (which includes
# local_remove_useless_blockwise) on top of the base MLX mode.
mlx_linalg_mode = mlx_mode.including("blockwise")
14+
15+
16+
@pytest.mark.parametrize("lower", [True, False])
def test_mlx_cholesky(lower):
    """Cholesky factorization agrees between the MLX backend and Python mode."""
    rng = np.random.default_rng()

    A = pt.tensor("A", shape=(5, 5))
    out = pt.linalg.cholesky(A, lower=lower)

    # Build a symmetric positive (semi-)definite test value, as cholesky requires.
    base = rng.normal(size=(5, 5)).astype(config.floatX)
    A_val = base @ base.T

    compare_mlx_and_py([A], [out], [A_val], mlx_mode=mlx_linalg_mode)
33+
34+
@pytest.mark.parametrize("assume_a", ["gen", "pos"])
def test_mlx_solve(assume_a):
    """Solve dispatch matches Python mode, warning when assume_a is unsupported."""
    rng = np.random.default_rng()

    A = pt.tensor("A", shape=(5, 5))
    b = pt.tensor("B", shape=(5, 5))
    out = pt.linalg.solve(A, b, b_ndim=2, assume_a=assume_a)

    # Symmetric A_val is valid under both 'gen' and 'pos' assumptions.
    base = rng.normal(size=(5, 5)).astype(config.floatX)
    A_val = base @ base.T
    b_val = rng.normal(size=(5, 5)).astype(config.floatX)

    # Only non-'gen' assumptions trigger the MLX fallback warning.
    if assume_a == "gen":
        context = contextlib.suppress()
    else:
        context = pytest.warns(
            UserWarning, match=f"MLX solve does not support assume_a={assume_a}"
        )

    with context:
        compare_mlx_and_py([A, b], [out], [A_val, b_val], mlx_mode=mlx_linalg_mode)

0 commit comments

Comments
 (0)