Skip to content

Commit d26ab32

Browse files
Implement KroneckerProduct dispatch in mlx backend
1 parent dbc8c32 commit d26ab32

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

pytensor/link/mlx/dispatch/nlinalg.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import mlx.core as mx
22

33
from pytensor.link.mlx.dispatch.basic import mlx_funcify
4-
from pytensor.tensor.nlinalg import (
5-
SVD,
6-
)
4+
from pytensor.tensor.nlinalg import SVD, KroneckerProduct
75

86

97
@mlx_funcify.register(SVD)
@@ -35,3 +33,15 @@ def svd_full(x):
3533
return svd_full
3634
else:
3735
return svd_S_only
36+
37+
38+
@mlx_funcify.register(KroneckerProduct)
39+
def mlx_funcify_KroneckerProduct(op, node, **kwargs):
40+
otype = node.outputs[0].dtype
41+
mx_otype = getattr(mx, otype)
42+
stream = mx.cpu if otype == "float64" else mx.gpu
43+
44+
def kron(a, b):
45+
return mx.kron(a, b, stream=stream).astype(mx_otype, stream=stream)
46+
47+
return kron

tests/link/mlx/test_nlinalg.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,19 @@ def test_mlx_svd(compute_uv):
2525
[A_val],
2626
mlx_mode=mlx_linalg_mode,
2727
)
28+
29+
30+
def test_mlx_kron():
31+
rng = np.random.default_rng()
32+
33+
A = pt.matrix(name="A")
34+
B = pt.matrix(name="B")
35+
A_val, B_val = rng.normal(size=(2, 3, 3)).astype(config.floatX)
36+
out = pt.linalg.kron(A, B)
37+
38+
compare_mlx_and_py(
39+
[A, B],
40+
[out],
41+
[A_val, B_val],
42+
mlx_mode=mlx_linalg_mode,
43+
)

0 commit comments

Comments
 (0)