Skip to content

Commit 4f55f04

Browse files
Implement SolveTriangular dispatch in mlx backend
1 parent f6a11df commit 4f55f04

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

pytensor/link/mlx/dispatch/slinalg.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import mlx.core as mx
44

55
from pytensor.link.mlx.dispatch.basic import mlx_funcify
6-
from pytensor.tensor.slinalg import Solve
6+
from pytensor.tensor.slinalg import Solve, SolveTriangular
77

88

99
@mlx_funcify.register(Solve)
@@ -23,3 +23,19 @@ def solve(a, b):
2323
return mx.linalg.solve(a, b, stream=mx.cpu).astype(out_dtype, stream=mx.cpu)
2424

2525
return solve
26+
27+
28+
@mlx_funcify.register(SolveTriangular)
29+
def mlx_funcify_SolveTriangular(op, node, **kwargs):
30+
lower = op.lower
31+
out_dtype = getattr(mx, node.outputs[0].dtype)
32+
33+
def solve_triangular(A, b):
34+
return mx.linalg.solve_triangular(
35+
A,
36+
b,
37+
upper=not lower,
38+
stream=mx.cpu, # MLX only supports solve_triangular on CPU
39+
).astype(out_dtype, stream=mx.cpu)
40+
41+
return solve_triangular

tests/link/mlx/test_slinalg.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,23 @@ def test_mlx_solve(assume_a):
5555

5656
with context:
5757
compare_mlx_and_py([A, b], [out], [A_val, b_val], mlx_mode=mlx_linalg_mode)
58+
59+
60+
@pytest.mark.parametrize("lower, trans", [(False, False), (True, True)])
61+
def test_mlx_SolveTriangular(lower, trans):
62+
rng = np.random.default_rng()
63+
64+
A = pt.tensor("A", shape=(5, 5))
65+
b = pt.tensor("B", shape=(5, 5))
66+
67+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
68+
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
69+
70+
out = pt.linalg.solve_triangular(
71+
A,
72+
b,
73+
trans=0,
74+
lower=lower,
75+
unit_diagonal=False,
76+
)
77+
compare_mlx_and_py([A, b], [out], [A_val, b_val], mlx_mode=mlx_linalg_mode)

0 commit comments

Comments
 (0)