Skip to content

Commit 5507d57

Browse files
Implement MatrixInv and MatrixPinv dispatch in mlx backend
1 parent d26ab32 commit 5507d57

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

pytensor/link/mlx/dispatch/nlinalg.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import mlx.core as mx
22

33
from pytensor.link.mlx.dispatch.basic import mlx_funcify
4-
from pytensor.tensor.nlinalg import SVD, KroneckerProduct
4+
from pytensor.tensor.nlinalg import SVD, KroneckerProduct, MatrixInverse, MatrixPinv
55

66

77
@mlx_funcify.register(SVD)
@@ -45,3 +45,23 @@ def kron(a, b):
4545
return mx.kron(a, b, stream=stream).astype(mx_otype, stream=stream)
4646

4747
return kron
48+
49+
50+
@mlx_funcify.register(MatrixInverse)
51+
def mlx_funcify_MatrixInverse(op, node, **kwargs):
52+
otype = getattr(mx, node.outputs[0].dtype)
53+
54+
def inv(x):
55+
return mx.linalg.inv(x, stream=mx.cpu).astype(otype, stream=mx.cpu)
56+
57+
return inv
58+
59+
60+
@mlx_funcify.register(MatrixPinv)
61+
def mlx_funcify_MatrixPinv(op, node, **kwargs):
62+
otype = getattr(mx, node.outputs[0].dtype)
63+
64+
def pinv(x):
65+
return mx.linalg.pinv(x, stream=mx.cpu).astype(otype, stream=mx.cpu)
66+
67+
return pinv

tests/link/mlx/test_nlinalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,21 @@ def test_mlx_kron():
4141
[A_val, B_val],
4242
mlx_mode=mlx_linalg_mode,
4343
)
44+
45+
46+
@pytest.mark.parametrize("op", [pt.linalg.inv, pt.linalg.pinv], ids=["inv", "pinv"])
47+
def test_mlx_inv(op):
48+
rng = np.random.default_rng()
49+
50+
A = pt.matrix(name="A")
51+
A_val = rng.normal(size=(3, 3)).astype(config.floatX)
52+
A_val = A_val @ A_val.T
53+
54+
out = op(A)
55+
56+
compare_mlx_and_py(
57+
[A],
58+
[out],
59+
[A_val],
60+
mlx_mode=mlx_linalg_mode,
61+
)

0 commit comments

Comments
 (0)