Skip to content

Commit 4dfd488

Browse files
Implement SVD dispatch in mlx backend
1 parent a9527fa commit 4dfd488

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
import pytensor.link.mlx.dispatch.signal.conv
1212
import pytensor.link.mlx.dispatch.blockwise
1313
import pytensor.link.mlx.dispatch.slinalg
14+
import pytensor.link.mlx.dispatch.nlinalg
1415
# isort: on
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import mlx.core as mx
2+
3+
from pytensor.link.mlx.dispatch.basic import mlx_funcify
4+
from pytensor.tensor.nlinalg import (
5+
SVD,
6+
)
7+
8+
9+
@mlx_funcify.register(SVD)
10+
def mlx_funcify_SVD(op, node, **kwargs):
11+
full_matrices = op.full_matrices
12+
compute_uv = op.compute_uv
13+
otype = (
14+
getattr(mx, node.outputs[0].dtype)
15+
if not compute_uv
16+
else [getattr(mx, output.dtype) for output in node.outputs]
17+
)
18+
19+
if not full_matrices:
20+
raise TypeError("full_matrices=False is not supported in the mlx backend.")
21+
22+
def svd_S_only(x):
23+
return mx.linalg.svd(x, compute_uv=False, stream=mx.cpu).astype(
24+
otype, stream=mx.cpu
25+
)
26+
27+
def svd_full(x):
28+
outputs = mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
29+
return (
30+
output.astype(typ, stream=mx.cpu)
31+
for output, typ in zip(outputs, otype, strict=True)
32+
)
33+
34+
if compute_uv:
35+
return svd_full
36+
else:
37+
return svd_S_only

tests/link/mlx/test_nlinalg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor.tensor as pt
5+
from pytensor import config
6+
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
7+
8+
9+
mlx_linalg_mode = mlx_mode.including("blockwise")
10+
11+
12+
@pytest.mark.parametrize("compute_uv", [True, False])
13+
def test_mlx_svd(compute_uv):
14+
rng = np.random.default_rng()
15+
16+
A = pt.matrix(name="X")
17+
A_val = rng.normal(size=(3, 3)).astype(config.floatX)
18+
A_val = A_val @ A_val.T
19+
20+
out = pt.linalg.svd(A, compute_uv=compute_uv)
21+
22+
compare_mlx_and_py(
23+
[A],
24+
out,
25+
[A_val],
26+
mlx_mode=mlx_linalg_mode,
27+
)

0 commit comments

Comments
 (0)