Skip to content

Commit ad80ef9

Browse files
authored
Fix(exla): triangular_solve with batched matrix input (#1596)
1 parent 73f799f commit ad80ef9

File tree

2 files changed

+112
-16
lines changed

2 files changed

+112
-16
lines changed

exla/lib/exla/defn.ex

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -782,26 +782,27 @@ defmodule EXLA.Defn do
782782
lower = Keyword.fetch!(opts, :lower)
783783
transform = Keyword.fetch!(opts, :transform_a)
784784

785-
case Value.get_typespec(b).shape do
786-
{dim} ->
787-
b_shape = {dim, 1}
785+
a_shape = Value.get_typespec(a).shape
786+
b_shape = Value.get_typespec(b).shape
788787

789-
b =
790-
b
791-
|> to_type(type)
792-
|> Value.reshape(Typespec.tensor(type, b_shape))
788+
if tuple_size(a_shape) > tuple_size(b_shape) do
789+
b_shape = Tuple.insert_at(b_shape, tuple_size(b_shape), 1)
793790

794-
typespec = Typespec.tensor(type, b_shape)
791+
b =
792+
b
793+
|> to_type(type)
794+
|> Value.reshape(Typespec.tensor(type, b_shape))
795795

796-
to_type(a, type)
797-
|> Value.triangular_solve(b, left_side, lower, transform, typespec)
798-
|> Value.reshape(Typespec.tensor(type, ans.shape))
796+
typespec = Typespec.tensor(type, b_shape)
799797

800-
_ ->
801-
typespec = Typespec.tensor(type, ans.shape)
798+
to_type(a, type)
799+
|> Value.triangular_solve(b, left_side, lower, transform, typespec)
800+
|> Value.reshape(Typespec.tensor(type, ans.shape))
801+
else
802+
typespec = Typespec.tensor(type, ans.shape)
802803

803-
to_type(a, type)
804-
|> Value.triangular_solve(to_type(b, type), left_side, lower, transform, typespec)
804+
to_type(a, type)
805+
|> Value.triangular_solve(to_type(b, type), left_side, lower, transform, typespec)
805806
end
806807
end
807808

exla/test/exla/nx_linalg_doctest_test.exs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ defmodule EXLA.NxLinAlgDoctestTest do
22
use EXLA.Case, async: true
33
import Nx, only: :sigils
44

5+
setup do
6+
Nx.default_backend(EXLA.Backend)
7+
:ok
8+
end
9+
510
@function_clause_error_doctests [
611
solve: 2,
712
triangular_solve: 3
@@ -15,7 +20,8 @@ defmodule EXLA.NxLinAlgDoctestTest do
1520
least_squares: 3,
1621
determinant: 1,
1722
matrix_power: 2,
18-
lu: 2
23+
lu: 2,
24+
qr: 2
1925
]
2026

2127
@excluded_doctests @function_clause_error_doctests ++
@@ -402,4 +408,93 @@ defmodule EXLA.NxLinAlgDoctestTest do
402408
end
403409
end
404410
end
411+
412+
describe "triangular_solve" do
413+
test "works with batched input" do
414+
a =
415+
Nx.tensor([
416+
[
417+
[-1, 0, 0],
418+
[1, 1, 0],
419+
[1, 1, 1]
420+
],
421+
[
422+
[2, 0, 0],
423+
[4, -2, 0],
424+
[-5, 1, 3]
425+
]
426+
])
427+
428+
b =
429+
Nx.tensor([
430+
[1.0, 2.0, 3.0],
431+
[6, 10, 1]
432+
])
433+
434+
assert_equal(Nx.dot(a, [2], [0], Nx.LinAlg.triangular_solve(a, b), [1], [0]), b)
435+
end
436+
437+
test "works with B that has more columns than rows" do
438+
a =
439+
Nx.tensor(
440+
[
441+
[1, 0],
442+
[1, 1]
443+
],
444+
type: :f64
445+
)
446+
447+
b =
448+
Nx.tensor(
449+
[
450+
[1, 1, 1],
451+
[2, 2, 2]
452+
],
453+
type: :f64
454+
)
455+
456+
x = Nx.LinAlg.triangular_solve(a, b)
457+
458+
assert_equal(
459+
x,
460+
Nx.tensor(
461+
[
462+
[1, 1, 1],
463+
[1, 1, 1]
464+
],
465+
type: :f64
466+
)
467+
)
468+
end
469+
470+
test "property" do
471+
a = Nx.tensor([[1, 0, 0], [1, 1, 0], [0, 1, 1]])
472+
b = Nx.tensor([[1.0, 2.0, 3.0], [2.0, 2.0, 4.0], [2.0, 0.0, 1.0]])
473+
assert_equal(Nx.dot(a, Nx.LinAlg.triangular_solve(a, b)), b)
474+
475+
upper = Nx.transpose(a)
476+
assert_equal(Nx.dot(upper, Nx.LinAlg.triangular_solve(upper, b, lower: false)), b)
477+
478+
assert_equal(
479+
Nx.dot(
480+
Nx.LinAlg.triangular_solve(upper, b, left_side: false, lower: false),
481+
upper
482+
),
483+
b
484+
)
485+
486+
assert_equal(
487+
Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose),
488+
Nx.LinAlg.triangular_solve(upper, b, lower: false)
489+
)
490+
491+
assert_equal(
492+
Nx.dot(
493+
Nx.transpose(a),
494+
Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose)
495+
),
496+
b
497+
)
498+
end
499+
end
405500
end

0 commit comments

Comments
 (0)