Skip to content

Commit cdc4cb2

Browse files
committed
fix(exla): triangular_solve with batched matrix input
1 parent ae104a2 commit cdc4cb2

File tree

3 files changed

+51
-55
lines changed

3 files changed

+51
-55
lines changed

exla/lib/exla/defn.ex

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -782,26 +782,27 @@ defmodule EXLA.Defn do
782782
lower = Keyword.fetch!(opts, :lower)
783783
transform = Keyword.fetch!(opts, :transform_a)
784784

785-
case Value.get_typespec(b).shape do
786-
{dim} ->
787-
b_shape = {dim, 1}
785+
a_shape = Value.get_typespec(a).shape
786+
b_shape = Value.get_typespec(b).shape
788787

789-
b =
790-
b
791-
|> to_type(type)
792-
|> Value.reshape(Typespec.tensor(type, b_shape))
788+
if tuple_size(a_shape) > tuple_size(b_shape) do
789+
b_shape = Tuple.insert_at(b_shape, tuple_size(b_shape), 1)
793790

794-
typespec = Typespec.tensor(type, b_shape)
791+
b =
792+
b
793+
|> to_type(type)
794+
|> Value.reshape(Typespec.tensor(type, b_shape))
795795

796-
to_type(a, type)
797-
|> Value.triangular_solve(b, left_side, lower, transform, typespec)
798-
|> Value.reshape(Typespec.tensor(type, ans.shape))
796+
typespec = Typespec.tensor(type, b_shape)
799797

800-
_ ->
801-
typespec = Typespec.tensor(type, ans.shape)
798+
to_type(a, type)
799+
|> Value.triangular_solve(b, left_side, lower, transform, typespec)
800+
|> Value.reshape(Typespec.tensor(type, ans.shape))
801+
else
802+
typespec = Typespec.tensor(type, ans.shape)
802803

803-
to_type(a, type)
804-
|> Value.triangular_solve(to_type(b, type), left_side, lower, transform, typespec)
804+
to_type(a, type)
805+
|> Value.triangular_solve(to_type(b, type), left_side, lower, transform, typespec)
805806
end
806807
end
807808

exla/lib/exla/mlir/value.ex

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -598,18 +598,7 @@ defmodule EXLA.MLIR.Value do
598598
transform,
599599
typespec
600600
) do
601-
602-
new_b_shape = {2, 3, 1}
603-
new_b_typespec = %{typespec | shape: new_b_shape}
604-
b = reshape(b, new_b_typespec)
605-
606-
expected_output_shape = {2, 3, 1}
607-
new_typespec = %{typespec | shape: expected_output_shape}
608-
result_types = typespecs_to_mlir_types([new_typespec])
609-
610-
611-
612-
# result_types = typespecs_to_mlir_types([typespec])
601+
result_types = typespecs_to_mlir_types([typespec])
613602

614603
complex? = Nx.Type.complex?(typespec.type)
615604

@@ -627,13 +616,8 @@ defmodule EXLA.MLIR.Value do
627616
transpose_a: transpose_a
628617
]
629618

630-
631-
632619
op(func, "stablehlo.triangular_solve", [a, b], result_types, attributes: attributes)
633620
|> one!()
634-
|> reshape(%{typespec | shape: {2, 3}})
635-
636-
637621
end
638622

639623
def dynamic_update_slice(%Value{function: func} = operand, updates, starts, typespec) do

exla/test/exla/nx_linalg_doctest_test.exs

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ defmodule EXLA.NxLinAlgDoctestTest do
2020
least_squares: 3,
2121
determinant: 1,
2222
matrix_power: 2,
23-
lu: 2
23+
lu: 2,
24+
qr: 2
2425
]
2526

2627
@excluded_doctests @function_clause_error_doctests ++
@@ -430,7 +431,7 @@ defmodule EXLA.NxLinAlgDoctestTest do
430431
[6, 10, 1]
431432
])
432433

433-
assert Nx.dot(a, [2], [0], Nx.LinAlg.triangular_solve(a, b), [1], [0]) == b
434+
assert_equal(Nx.dot(a, [2], [0], Nx.LinAlg.triangular_solve(a, b), [1], [0]), b)
434435
end
435436

436437
test "works with B that has more columns than rows" do
@@ -454,36 +455,46 @@ defmodule EXLA.NxLinAlgDoctestTest do
454455

455456
x = Nx.LinAlg.triangular_solve(a, b)
456457

457-
assert x ==
458-
Nx.tensor(
459-
[
460-
[1, 1, 1],
461-
[1, 1, 1]
462-
],
463-
type: :f64
464-
)
458+
assert_equal(
459+
x,
460+
Nx.tensor(
461+
[
462+
[1, 1, 1],
463+
[1, 1, 1]
464+
],
465+
type: :f64
466+
)
467+
)
465468
end
466469

467470
test "property" do
468471
a = Nx.tensor([[1, 0, 0], [1, 1, 0], [0, 1, 1]])
469472
b = Nx.tensor([[1.0, 2.0, 3.0], [2.0, 2.0, 4.0], [2.0, 0.0, 1.0]])
470-
assert Nx.dot(a, Nx.LinAlg.triangular_solve(a, b)) == b
473+
assert_equal(Nx.dot(a, Nx.LinAlg.triangular_solve(a, b)), b)
471474

472475
upper = Nx.transpose(a)
473-
assert Nx.dot(upper, Nx.LinAlg.triangular_solve(upper, b, lower: false)) == b
474-
475-
assert Nx.dot(
476-
Nx.LinAlg.triangular_solve(upper, b, left_side: false, lower: false),
477-
upper
478-
) == b
476+
assert_equal(Nx.dot(upper, Nx.LinAlg.triangular_solve(upper, b, lower: false)), b)
477+
478+
assert_equal(
479+
Nx.dot(
480+
Nx.LinAlg.triangular_solve(upper, b, left_side: false, lower: false),
481+
upper
482+
),
483+
b
484+
)
479485

480-
assert Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose) ==
481-
Nx.LinAlg.triangular_solve(upper, b, lower: false)
486+
assert_equal(
487+
Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose),
488+
Nx.LinAlg.triangular_solve(upper, b, lower: false)
489+
)
482490

483-
assert Nx.dot(
484-
Nx.transpose(a),
485-
Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose)
486-
) == b
491+
assert_equal(
492+
Nx.dot(
493+
Nx.transpose(a),
494+
Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose)
495+
),
496+
b
497+
)
487498
end
488499
end
489500
end

0 commit comments

Comments
 (0)