Skip to content

Commit 592e0a1

Browse files
authored
test(exla): add more tests for LinAlg functions (#1594)
1 parent cf35078 commit 592e0a1

File tree

1 file changed

+320
-9
lines changed

1 file changed

+320
-9
lines changed

exla/test/exla/nx_linalg_doctest_test.exs

Lines changed: 320 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
defmodule EXLA.MLIR.NxLinAlgDoctestTest do
1+
defmodule EXLA.NxLinAlgDoctestTest do
22
use EXLA.Case, async: true
3-
4-
@invalid_type_error_doctests [
5-
svd: 2,
6-
pinv: 2
7-
]
3+
import Nx, only: :sigils
84

95
@function_clause_error_doctests [
10-
solve: 2
6+
solve: 2,
7+
triangular_solve: 3
118
]
129

1310
@rounding_error_doctests [
14-
triangular_solve: 3,
11+
svd: 2,
12+
pinv: 2,
1513
eigh: 2,
1614
cholesky: 1,
1715
least_squares: 3,
@@ -22,7 +20,6 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
2220

2321
@excluded_doctests @function_clause_error_doctests ++
2422
@rounding_error_doctests ++
25-
@invalid_type_error_doctests ++
2623
[:moduledoc]
2724
doctest Nx.LinAlg, except: @excluded_doctests
2825

@@ -91,4 +88,318 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
9188
end
9289
end
9390
end
91+
92+
describe "cholesky" do
  # Property test: build a matrix A that is Hermitian and positive definite
  # by construction, then check that cholesky(A) returns a factor L with
  # L · L^H ~= A (batched over the leading axis of shape {3, 4, 4}).
  test "property" do
    key = Nx.Random.key(System.unique_integer())

    # 10 rounds per type; `reduce: key` threads the PRNG key through iterations.
    for _ <- 1..10, type <- [{:f, 32}, {:c, 64}], reduce: key do
      key ->
        # Generate random L matrix so we can construct
        # a factorizable A matrix:
        shape = {3, 4, 4}

        {a_prime, key} = Nx.Random.normal(key, 0, 1, shape: shape, type: type)

        a_prime = Nx.add(a_prime, Nx.eye(shape))
        # b = A'^H · A' is Hermitian positive semidefinite by construction.
        b = Nx.dot(Nx.LinAlg.adjoint(a_prime), [-1], [0], a_prime, [-2], [0])

        # Shift the diagonal by 0.1 so A is safely positive definite.
        d = Nx.eye(shape) |> Nx.multiply(0.1)

        a = Nx.add(b, d)

        assert l = Nx.LinAlg.cholesky(a)
        # Batched reconstruction: contract the last axis of l with the
        # first non-batch axis of adjoint(l), batching over axis 0.
        assert_all_close(Nx.dot(l, [2], [0], Nx.LinAlg.adjoint(l), [1], [0]), a, atol: 1.0e-2)
        key
    end
  end
end
117+
118+
describe "least_squares" do
  # Property test: perturb an exactly-solvable system Ax = y with small
  # noise and check the least-squares solution keeps the residual within
  # a tolerance proportional to the injected noise.
  test "properties for linear equations" do
    key = Nx.Random.key(System.unique_integer())

    # Calculate linear equations Ax = y by using least-squares solution
    for {m, n} <- [{2, 2}, {3, 2}, {4, 3}], reduce: key do
      key ->
        # Generate x as temporary solution and A as base matrix
        {a_base, key} = Nx.Random.randint(key, 1, 10, shape: {m, n})
        {x_temp, key} = Nx.Random.randint(key, 1, 10, shape: {n})

        # Generate y as base vector by x and A
        # to prepare an equation that can be solved exactly
        y_base = Nx.dot(a_base, x_temp)

        # Generate y as random noise vector and A as random noise matrix
        noise_eps = 1.0e-2
        {a_noise, key} = Nx.Random.uniform(key, 0, noise_eps, shape: {m, n})
        {y_noise, key} = Nx.Random.uniform(key, 0, noise_eps, shape: {m})

        # Add noise to prepare equations that cannot be solved without approximation,
        # such as the least-squares method
        a = Nx.add(a_base, a_noise)
        y = Nx.add(y_base, y_noise)

        # Calculate least-squares solution to a linear matrix equation Ax = y
        x = Nx.LinAlg.least_squares(a, y)

        # Check the residual: A·x should reproduce y up to the noise level.

        assert_all_close(Nx.dot(a, x), y, atol: noise_eps * 10)

        key
    end
  end
end
154+
155+
describe "determinant" do
  # determinant/1 accepts a batch of square matrices (leading axis) and
  # returns one determinant per batch entry. Expected values below were
  # computed by hand for 2x2, 3x3 and 4x4 inputs.
  test "supports batched matrices" do
    batched_2x2 = Nx.tensor([[[2, 3], [4, 5]], [[6, 3], [4, 8]]])
    assert_equal(Nx.LinAlg.determinant(batched_2x2), Nx.tensor([-2.0, 36.0]))

    batched_3x3 =
      Nx.tensor([
        [[1.0, 2.0, 3.0], [1.0, 5.0, 3.0], [7.0, 6.0, 9.0]],
        [[5.0, 2.0, 3.0], [8.0, 5.0, 4.0], [3.0, 1.0, -9.0]]
      ])

    assert_equal(Nx.LinAlg.determinant(batched_3x3), Nx.tensor([-36.0, -98.0]))

    batched_4x4 =
      Nx.tensor([
        [
          [1.0, 2.0, 3.0, 0.0],
          [1.0, 5.0, 3.0, 0.0],
          [7.0, 6.0, 9.0, 0.0],
          [0.0, -11.0, 2.0, 3.0]
        ],
        [
          [5.0, 2.0, 3.0, 0.0],
          [8.0, 5.0, 4.0, 0.0],
          [3.0, 1.0, -9.0, 0.0],
          [8.0, 2.0, -4.0, 5.0]
        ]
      ])

    # 4x4 case is checked approximately rather than exactly.
    assert_all_close(Nx.LinAlg.determinant(batched_4x4), Nx.tensor([-108.0, -490]))
  end
end
187+
188+
describe "matrix_power" do
  # Covers positive, zero, negative and batched exponents, including
  # complex-typed matrices built with the ~MAT sigil.
  test "supports complex with positive exponent" do
    m = ~MAT[
      1 1i
      -1i 1
    ]

    exponent = 5

    # For this particular matrix, m^k equals 2^(k - 1) * m.
    assert_all_close(Nx.LinAlg.matrix_power(m, exponent), Nx.multiply(2 ** (exponent - 1), m))
  end

  test "supports complex with 0 exponent" do
    m = ~MAT[
      1 1i
      -1i 1
    ]

    # Raising to the zeroth power yields the identity matrix.
    assert_all_close(Nx.LinAlg.matrix_power(m, 0), Nx.eye(Nx.shape(m)))
  end

  test "supports complex with negative exponent" do
    m = ~MAT[
      1 -0.5i
      0 0.5
    ]

    expected = ~MAT[
      1 15i
      0 16
    ]

    assert_all_close(Nx.LinAlg.matrix_power(m, -4), expected)
  end

  test "supports batched matrices" do
    input =
      Nx.tensor([
        [[5, 3], [1, 2]],
        [[9, 0], [4, 7]]
      ])

    expected =
      Nx.tensor([
        [[161, 126], [42, 35]],
        [[729, 0], [772, 343]]
      ])

    assert_all_close(Nx.LinAlg.matrix_power(input, 3), expected)
  end
end
239+
240+
describe "lu" do
  # Property test: multiply random lower- and upper-triangular matrices to
  # obtain a factorizable A, then check that lu(A) reconstructs A via
  # P · L · U (batched over the leading axis of shape {3, 4, 4}).
  test "property" do
    key = Nx.Random.key(System.unique_integer())

    # 10 rounds per type; `reduce: key` threads the PRNG key through iterations.
    for _ <- 1..10, type <- [{:f, 32}, {:c, 64}], reduce: key do
      key ->
        # Generate random L and U matrices so we can construct
        # a factorizable A matrix:
        shape = {3, 4, 4}
        # Boolean masks selecting the lower (row >= col) and upper triangles.
        lower_selector = Nx.iota(shape, axis: 1) |> Nx.greater_equal(Nx.iota(shape, axis: 2))
        upper_selector = Nx.LinAlg.adjoint(lower_selector)

        {l_prime, key} = Nx.Random.uniform(key, 0, 1, shape: shape, type: type)
        l_prime = Nx.multiply(l_prime, lower_selector)

        {u_prime, key} = Nx.Random.uniform(key, 0, 1, shape: shape, type: type)
        u_prime = Nx.multiply(u_prime, upper_selector)

        # Batched matmul: contract axis 2 of l_prime with axis 1 of u_prime.
        a = Nx.dot(l_prime, [2], [0], u_prime, [1], [0])

        assert {p, l, u} = Nx.LinAlg.lu(a)

        # Reconstruct A as P · L · U and compare against the input.
        actual = p |> Nx.dot([2], [0], l, [1], [0]) |> Nx.dot([2], [0], u, [1], [0])
        assert_all_close(actual, a)
        key
    end
  end
end
268+
269+
describe "svd" do
  # Exercises SVD on tall, batched, vectorized, vector-shaped and
  # all-zero inputs, checking reconstruction t ~= U · S · V in each case.
  test "finds the singular values of tall matrices" do
    t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]])

    # All three outputs must share one floating-point type.
    assert {%{type: output_type} = u, %{type: output_type} = s, %{type: output_type} = v} =
             Nx.LinAlg.svd(t, max_iter: 1000)

    # Embed the singular values into a rectangular {4, 3} diagonal matrix.
    s_matrix = 0 |> Nx.broadcast({4, 3}) |> Nx.put_diagonal(s)

    assert_all_close(t, u |> Nx.dot(s_matrix) |> Nx.dot(v), atol: 1.0e-2, rtol: 1.0e-2)

    # Reference factors below are pre-computed to 3 decimal places.
    assert_all_close(
      u,
      Nx.tensor([
        [0.140, 0.824, 0.521, -0.166],
        [0.343, 0.426, -0.571, 0.611],
        [0.547, 0.0278, -0.422, -0.722],
        [0.750, -0.370, 0.472, 0.277]
      ]),
      atol: 1.0e-3,
      rtol: 1.0e-3
    )

    assert_all_close(Nx.tensor([25.462, 1.291, 0.0]), s, atol: 1.0e-3, rtol: 1.0e-3)

    assert_all_close(
      Nx.tensor([
        [0.504, 0.574, 0.644],
        [-0.760, -0.057, 0.646],
        [0.408, -0.816, 0.408]
      ]),
      v,
      atol: 1.0e-3,
      rtol: 1.0e-3
    )
  end

  test "works with batched matrices" do
    t =
      Nx.tensor([
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
        [[1.0, 2.0, 3.0], [0.0, 4.0, 0.0], [0.0, 0.0, 9.0]]
      ])

    assert {u, s, v} = Nx.LinAlg.svd(t)

    # Build one diagonal matrix per batch entry from the singular values.
    s_matrix =
      Nx.stack([
        Nx.broadcast(0, {3, 3}) |> Nx.put_diagonal(s[0]),
        Nx.broadcast(0, {3, 3}) |> Nx.put_diagonal(s[1])
      ])

    # Batched reconstruction via axis-contracting dot products.
    reconstructed_t =
      u
      |> Nx.dot([2], [0], s_matrix, [1], [0])
      |> Nx.dot([2], [0], v, [1], [0])

    assert_all_close(t, reconstructed_t, atol: 1.0e-2, rtol: 1.0e-2)
  end

  test "works with vectorized tensors matrices" do
    t =
      Nx.tensor([
        [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]],
        [[[1.0, 2.0, 3.0], [0.0, 4.0, 0.0], [0.0, 0.0, 9.0]]]
      ])
      |> Nx.vectorize(x: 2, y: 1)

    assert {u, s, v} = Nx.LinAlg.svd(t)

    s_matrix = Nx.put_diagonal(Nx.broadcast(0, {3, 3}), s)

    # With vectorized axes, plain Nx.dot/2 operates on the inner {3, 3}.
    reconstructed_t =
      u
      |> Nx.dot(s_matrix)
      |> Nx.dot(v)

    # The vectorized axes must survive the round trip unchanged.
    assert reconstructed_t.vectorized_axes == [x: 2, y: 1]
    assert reconstructed_t.shape == {3, 3}

    assert_all_close(Nx.devectorize(t), Nx.devectorize(reconstructed_t),
      atol: 1.0e-2,
      rtol: 1.0e-2
    )
  end

  test "works with vectors" do
    # A {2, 1} column vector still factorizes; pad s with a zero row so
    # the {2, 2} U composes with the {1, 1} Vt.
    t = Nx.tensor([[-2], [1]])

    {u, s, vt} = Nx.LinAlg.svd(t)
    assert_all_close(u |> Nx.dot(Nx.stack([s, Nx.tensor([0])])) |> Nx.dot(vt), t)
  end

  test "works with zero-tensor" do
    # Degenerate input: all singular values are zero and U/Vt come back
    # as identity matrices for square and rectangular shapes alike.
    for {m, n, k} <- [{3, 3, 3}, {3, 4, 3}, {4, 3, 3}] do
      t = Nx.broadcast(0, {m, n})
      {u, s, vt} = Nx.LinAlg.svd(t)
      assert_all_close(u, Nx.eye({m, m}))
      assert_all_close(s, Nx.broadcast(0, {k}))
      assert_all_close(vt, Nx.eye({n, n}))
    end
  end
end
372+
373+
describe "pinv" do
  # Regression-style property test: build matrices with a known singular
  # spectrum that includes an exact zero, and check pinv neither raises
  # (division by zero) nor deviates from the analytic pseudo-inverse.
  test "does not raise for 0 singular values" do
    key = Nx.Random.key(System.unique_integer())

    for {m, n} <- [{3, 4}, {3, 3}, {4, 3}], reduce: key do
      key ->
        # generate u and vt as random orthonormal matrices
        # (Q factors of random matrices; R factors are discarded)
        {base_u, key} = Nx.Random.uniform(key, 0, 1, shape: {m, m}, type: :f64)
        {u, _} = Nx.LinAlg.qr(base_u)
        {base_vt, key} = Nx.Random.uniform(key, 0, 1, shape: {n, n}, type: :f64)
        {vt, _} = Nx.LinAlg.qr(base_vt)

        # because min(m, n) is always 3, we can use fixed values here
        # the important thing is that there's at least one zero in the
        # diagonal, to ensure that we're guarding against 0 division
        zeros = Nx.broadcast(0, {m, n})
        s = Nx.put_diagonal(zeros, Nx.f64([1, 4, 0]))
        # Pseudo-inverse of the spectrum: reciprocals, with 0 kept as 0.
        s_inv = Nx.put_diagonal(Nx.transpose(zeros), Nx.f64([1, 0.25, 0]))

        # construct t with the given singular values
        t = u |> Nx.dot(s) |> Nx.dot(vt)
        pinv = Nx.LinAlg.pinv(t)

        # ensure that the returned pinv is close to what we expect:
        # pinv(t) == Vt^T · S⁺ · U^T for this real-valued construction
        assert_all_close(pinv, Nx.transpose(vt) |> Nx.dot(s_inv) |> Nx.dot(Nx.transpose(u)),
          atol: 1.0e-2
        )

        key
    end
  end
end
94405
end

0 commit comments

Comments
 (0)