We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4f55f04 commit a9527faCopy full SHA for a9527fa
pytensor/link/mlx/dispatch/slinalg.py
@@ -3,7 +3,20 @@
3
import mlx.core as mx
4
5
from pytensor.link.mlx.dispatch.basic import mlx_funcify
6
-from pytensor.tensor.slinalg import Solve, SolveTriangular
+from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
7
+
8
9
+@mlx_funcify.register(Cholesky)
10
+def mlx_funcify_Cholesky(op, node, **kwargs):
11
+ lower = op.lower
12
+ out_dtype = getattr(mx, node.outputs[0].dtype)
13
14
+ def cholesky(a):
15
+ return mx.linalg.cholesky(a, upper=not lower, stream=mx.cpu).astype(
16
+ out_dtype, stream=mx.cpu
17
+ )
18
19
+ return cholesky
20
21
22
@mlx_funcify.register(Solve)
0 commit comments