Skip to content

Commit a9527fa

Browse files
Implement Cholesky dispatch in mlx backend
1 parent 4f55f04 commit a9527fa

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

pytensor/link/mlx/dispatch/slinalg.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,20 @@
33
import mlx.core as mx
44

55
from pytensor.link.mlx.dispatch.basic import mlx_funcify
6-
from pytensor.tensor.slinalg import Solve, SolveTriangular
6+
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
7+
8+
9+
@mlx_funcify.register(Cholesky)
10+
def mlx_funcify_Cholesky(op, node, **kwargs):
11+
lower = op.lower
12+
out_dtype = getattr(mx, node.outputs[0].dtype)
13+
14+
def cholesky(a):
15+
return mx.linalg.cholesky(a, upper=not lower, stream=mx.cpu).astype(
16+
out_dtype, stream=mx.cpu
17+
)
18+
19+
return cholesky
720

821

922
@mlx_funcify.register(Solve)

0 commit comments

Comments
 (0)