1- defmodule EXLA.MLIR. NxLinAlgDoctestTest do
1+ defmodule EXLA.NxLinAlgDoctestTest do
22 use EXLA.Case , async: true
3-
4- @ invalid_type_error_doctests [
5- svd: 2 ,
6- pinv: 2
7- ]
3+ import Nx , only: :sigils
84
95 @ function_clause_error_doctests [
10- solve: 2
6+ solve: 2 ,
7+ triangular_solve: 3
118 ]
129
1310 @ rounding_error_doctests [
14- triangular_solve: 3 ,
11+ svd: 2 ,
12+ pinv: 2 ,
1513 eigh: 2 ,
1614 cholesky: 1 ,
1715 least_squares: 3 ,
@@ -22,7 +20,6 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
2220
2321 @ excluded_doctests @ function_clause_error_doctests ++
2422 @ rounding_error_doctests ++
25- @ invalid_type_error_doctests ++
2623 [ :moduledoc ]
2724 doctest Nx.LinAlg , except: @ excluded_doctests
2825
@@ -91,4 +88,318 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
9188 end
9289 end
9390 end
91+
92+ describe "cholesky" do
93+ test "property" do
94+ key = Nx.Random . key ( System . unique_integer ( ) )
95+
96+ for _ <- 1 .. 10 , type <- [ { :f , 32 } , { :c , 64 } ] , reduce: key do
97+ key ->
98+ # Generate random L matrix so we can construct
99+ # a factorizable A matrix:
100+ shape = { 3 , 4 , 4 }
101+
102+ { a_prime , key } = Nx.Random . normal ( key , 0 , 1 , shape: shape , type: type )
103+
104+ a_prime = Nx . add ( a_prime , Nx . eye ( shape ) )
105+ b = Nx . dot ( Nx.LinAlg . adjoint ( a_prime ) , [ - 1 ] , [ 0 ] , a_prime , [ - 2 ] , [ 0 ] )
106+
107+ d = Nx . eye ( shape ) |> Nx . multiply ( 0.1 )
108+
109+ a = Nx . add ( b , d )
110+
111+ assert l = Nx.LinAlg . cholesky ( a )
112+ assert_all_close ( Nx . dot ( l , [ 2 ] , [ 0 ] , Nx.LinAlg . adjoint ( l ) , [ 1 ] , [ 0 ] ) , a , atol: 1.0e-2 )
113+ key
114+ end
115+ end
116+ end
117+
118+ describe "least_squares" do
119+ test "properties for linear equations" do
120+ key = Nx.Random . key ( System . unique_integer ( ) )
121+
122+ # Calucate linear equations Ax = y by using least-squares solution
123+ for { m , n } <- [ { 2 , 2 } , { 3 , 2 } , { 4 , 3 } ] , reduce: key do
124+ key ->
125+ # Generate x as temporary solution and A as base matrix
126+ { a_base , key } = Nx.Random . randint ( key , 1 , 10 , shape: { m , n } )
127+ { x_temp , key } = Nx.Random . randint ( key , 1 , 10 , shape: { n } )
128+
129+ # Generate y as base vector by x and A
130+ # to prepare an equation that can be solved exactly
131+ y_base = Nx . dot ( a_base , x_temp )
132+
133+ # Generate y as random noise vector and A as random noise matrix
134+ noise_eps = 1.0e-2
135+ { a_noise , key } = Nx.Random . uniform ( key , 0 , noise_eps , shape: { m , n } )
136+ { y_noise , key } = Nx.Random . uniform ( key , 0 , noise_eps , shape: { m } )
137+
138+ # Add noise to prepare equations that cannot be solved without approximation,
139+ # such as the least-squares method
140+ a = Nx . add ( a_base , a_noise )
141+ y = Nx . add ( y_base , y_noise )
142+
143+ # Calculate least-squares solution to a linear matrix equation Ax = y
144+ x = Nx.LinAlg . least_squares ( a , y )
145+
146+ # Check linear matrix equation
147+
148+ assert_all_close ( Nx . dot ( a , x ) , y , atol: noise_eps * 10 )
149+
150+ key
151+ end
152+ end
153+ end
154+
155+ describe "determinant" do
156+ test "supports batched matrices" do
157+ two_by_two = Nx . tensor ( [ [ [ 2 , 3 ] , [ 4 , 5 ] ] , [ [ 6 , 3 ] , [ 4 , 8 ] ] ] )
158+ assert_equal ( Nx.LinAlg . determinant ( two_by_two ) , Nx . tensor ( [ - 2.0 , 36.0 ] ) )
159+
160+ three_by_three =
161+ Nx . tensor ( [
162+ [ [ 1.0 , 2.0 , 3.0 ] , [ 1.0 , 5.0 , 3.0 ] , [ 7.0 , 6.0 , 9.0 ] ] ,
163+ [ [ 5.0 , 2.0 , 3.0 ] , [ 8.0 , 5.0 , 4.0 ] , [ 3.0 , 1.0 , - 9.0 ] ]
164+ ] )
165+
166+ assert_equal ( Nx.LinAlg . determinant ( three_by_three ) , Nx . tensor ( [ - 36.0 , - 98.0 ] ) )
167+
168+ four_by_four =
169+ Nx . tensor ( [
170+ [
171+ [ 1.0 , 2.0 , 3.0 , 0.0 ] ,
172+ [ 1.0 , 5.0 , 3.0 , 0.0 ] ,
173+ [ 7.0 , 6.0 , 9.0 , 0.0 ] ,
174+ [ 0.0 , - 11.0 , 2.0 , 3.0 ]
175+ ] ,
176+ [
177+ [ 5.0 , 2.0 , 3.0 , 0.0 ] ,
178+ [ 8.0 , 5.0 , 4.0 , 0.0 ] ,
179+ [ 3.0 , 1.0 , - 9.0 , 0.0 ] ,
180+ [ 8.0 , 2.0 , - 4.0 , 5.0 ]
181+ ]
182+ ] )
183+
184+ assert_all_close ( Nx.LinAlg . determinant ( four_by_four ) , Nx . tensor ( [ - 108.0 , - 490 ] ) )
185+ end
186+ end
187+
188+ describe "matrix_power" do
189+ test "supports complex with positive exponent" do
190+ a = ~MAT[
191+ 1 1i
192+ -1i 1
193+ ]
194+
195+ n = 5
196+
197+ assert_all_close ( Nx.LinAlg . matrix_power ( a , n ) , Nx . multiply ( 2 ** ( n - 1 ) , a ) )
198+ end
199+
200+ test "supports complex with 0 exponent" do
201+ a = ~MAT[
202+ 1 1i
203+ -1i 1
204+ ]
205+
206+ assert_all_close ( Nx.LinAlg . matrix_power ( a , 0 ) , Nx . eye ( Nx . shape ( a ) ) )
207+ end
208+
209+ test "supports complex with negative exponent" do
210+ a = ~MAT[
211+ 1 -0.5i
212+ 0 0.5
213+ ]
214+
215+ result = ~MAT[
216+ 1 15i
217+ 0 16
218+ ]
219+
220+ assert_all_close ( Nx.LinAlg . matrix_power ( a , - 4 ) , result )
221+ end
222+
223+ test "supports batched matrices" do
224+ a =
225+ Nx . tensor ( [
226+ [ [ 5 , 3 ] , [ 1 , 2 ] ] ,
227+ [ [ 9 , 0 ] , [ 4 , 7 ] ]
228+ ] )
229+
230+ result =
231+ Nx . tensor ( [
232+ [ [ 161 , 126 ] , [ 42 , 35 ] ] ,
233+ [ [ 729 , 0 ] , [ 772 , 343 ] ]
234+ ] )
235+
236+ assert_all_close ( Nx.LinAlg . matrix_power ( a , 3 ) , result )
237+ end
238+ end
239+
240+ describe "lu" do
241+ test "property" do
242+ key = Nx.Random . key ( System . unique_integer ( ) )
243+
244+ for _ <- 1 .. 10 , type <- [ { :f , 32 } , { :c , 64 } ] , reduce: key do
245+ key ->
246+ # Generate random L and U matrices so we can construct
247+ # a factorizable A matrix:
248+ shape = { 3 , 4 , 4 }
249+ lower_selector = Nx . iota ( shape , axis: 1 ) |> Nx . greater_equal ( Nx . iota ( shape , axis: 2 ) )
250+ upper_selector = Nx.LinAlg . adjoint ( lower_selector )
251+
252+ { l_prime , key } = Nx.Random . uniform ( key , 0 , 1 , shape: shape , type: type )
253+ l_prime = Nx . multiply ( l_prime , lower_selector )
254+
255+ { u_prime , key } = Nx.Random . uniform ( key , 0 , 1 , shape: shape , type: type )
256+ u_prime = Nx . multiply ( u_prime , upper_selector )
257+
258+ a = Nx . dot ( l_prime , [ 2 ] , [ 0 ] , u_prime , [ 1 ] , [ 0 ] )
259+
260+ assert { p , l , u } = Nx.LinAlg . lu ( a )
261+
262+ actual = p |> Nx . dot ( [ 2 ] , [ 0 ] , l , [ 1 ] , [ 0 ] ) |> Nx . dot ( [ 2 ] , [ 0 ] , u , [ 1 ] , [ 0 ] )
263+ assert_all_close ( actual , a )
264+ key
265+ end
266+ end
267+ end
268+
269+ describe "svd" do
270+ test "finds the singular values of tall matrices" do
271+ t = Nx . tensor ( [ [ 1.0 , 2.0 , 3.0 ] , [ 4.0 , 5.0 , 6.0 ] , [ 7.0 , 8.0 , 9.0 ] , [ 10.0 , 11.0 , 12.0 ] ] )
272+
273+ assert { % { type: output_type } = u , % { type: output_type } = s , % { type: output_type } = v } =
274+ Nx.LinAlg . svd ( t , max_iter: 1000 )
275+
276+ s_matrix = 0 |> Nx . broadcast ( { 4 , 3 } ) |> Nx . put_diagonal ( s )
277+
278+ assert_all_close ( t , u |> Nx . dot ( s_matrix ) |> Nx . dot ( v ) , atol: 1.0e-2 , rtol: 1.0e-2 )
279+
280+ assert_all_close (
281+ u ,
282+ Nx . tensor ( [
283+ [ 0.140 , 0.824 , 0.521 , - 0.166 ] ,
284+ [ 0.343 , 0.426 , - 0.571 , 0.611 ] ,
285+ [ 0.547 , 0.0278 , - 0.422 , - 0.722 ] ,
286+ [ 0.750 , - 0.370 , 0.472 , 0.277 ]
287+ ] ) ,
288+ atol: 1.0e-3 ,
289+ rtol: 1.0e-3
290+ )
291+
292+ assert_all_close ( Nx . tensor ( [ 25.462 , 1.291 , 0.0 ] ) , s , atol: 1.0e-3 , rtol: 1.0e-3 )
293+
294+ assert_all_close (
295+ Nx . tensor ( [
296+ [ 0.504 , 0.574 , 0.644 ] ,
297+ [ - 0.760 , - 0.057 , 0.646 ] ,
298+ [ 0.408 , - 0.816 , 0.408 ]
299+ ] ) ,
300+ v ,
301+ atol: 1.0e-3 ,
302+ rtol: 1.0e-3
303+ )
304+ end
305+
306+ test "works with batched matrices" do
307+ t =
308+ Nx . tensor ( [
309+ [ [ 1.0 , 2.0 , 3.0 ] , [ 4.0 , 5.0 , 6.0 ] , [ 7.0 , 8.0 , 9.0 ] ] ,
310+ [ [ 1.0 , 2.0 , 3.0 ] , [ 0.0 , 4.0 , 0.0 ] , [ 0.0 , 0.0 , 9.0 ] ]
311+ ] )
312+
313+ assert { u , s , v } = Nx.LinAlg . svd ( t )
314+
315+ s_matrix =
316+ Nx . stack ( [
317+ Nx . broadcast ( 0 , { 3 , 3 } ) |> Nx . put_diagonal ( s [ 0 ] ) ,
318+ Nx . broadcast ( 0 , { 3 , 3 } ) |> Nx . put_diagonal ( s [ 1 ] )
319+ ] )
320+
321+ reconstructed_t =
322+ u
323+ |> Nx . dot ( [ 2 ] , [ 0 ] , s_matrix , [ 1 ] , [ 0 ] )
324+ |> Nx . dot ( [ 2 ] , [ 0 ] , v , [ 1 ] , [ 0 ] )
325+
326+ assert_all_close ( t , reconstructed_t , atol: 1.0e-2 , rtol: 1.0e-2 )
327+ end
328+
329+ test "works with vectorized tensors matrices" do
330+ t =
331+ Nx . tensor ( [
332+ [ [ [ 1.0 , 2.0 , 3.0 ] , [ 4.0 , 5.0 , 6.0 ] , [ 7.0 , 8.0 , 9.0 ] ] ] ,
333+ [ [ [ 1.0 , 2.0 , 3.0 ] , [ 0.0 , 4.0 , 0.0 ] , [ 0.0 , 0.0 , 9.0 ] ] ]
334+ ] )
335+ |> Nx . vectorize ( x: 2 , y: 1 )
336+
337+ assert { u , s , v } = Nx.LinAlg . svd ( t )
338+
339+ s_matrix = Nx . put_diagonal ( Nx . broadcast ( 0 , { 3 , 3 } ) , s )
340+
341+ reconstructed_t =
342+ u
343+ |> Nx . dot ( s_matrix )
344+ |> Nx . dot ( v )
345+
346+ assert reconstructed_t . vectorized_axes == [ x: 2 , y: 1 ]
347+ assert reconstructed_t . shape == { 3 , 3 }
348+
349+ assert_all_close ( Nx . devectorize ( t ) , Nx . devectorize ( reconstructed_t ) ,
350+ atol: 1.0e-2 ,
351+ rtol: 1.0e-2
352+ )
353+ end
354+
355+ test "works with vectors" do
356+ t = Nx . tensor ( [ [ - 2 ] , [ 1 ] ] )
357+
358+ { u , s , vt } = Nx.LinAlg . svd ( t )
359+ assert_all_close ( u |> Nx . dot ( Nx . stack ( [ s , Nx . tensor ( [ 0 ] ) ] ) ) |> Nx . dot ( vt ) , t )
360+ end
361+
362+ test "works with zero-tensor" do
363+ for { m , n , k } <- [ { 3 , 3 , 3 } , { 3 , 4 , 3 } , { 4 , 3 , 3 } ] do
364+ t = Nx . broadcast ( 0 , { m , n } )
365+ { u , s , vt } = Nx.LinAlg . svd ( t )
366+ assert_all_close ( u , Nx . eye ( { m , m } ) )
367+ assert_all_close ( s , Nx . broadcast ( 0 , { k } ) )
368+ assert_all_close ( vt , Nx . eye ( { n , n } ) )
369+ end
370+ end
371+ end
372+
373+ describe "pinv" do
374+ test "does not raise for 0 singular values" do
375+ key = Nx.Random . key ( System . unique_integer ( ) )
376+
377+ for { m , n } <- [ { 3 , 4 } , { 3 , 3 } , { 4 , 3 } ] , reduce: key do
378+ key ->
379+ # generate u and vt as random orthonormal matrices
380+ { base_u , key } = Nx.Random . uniform ( key , 0 , 1 , shape: { m , m } , type: :f64 )
381+ { u , _ } = Nx.LinAlg . qr ( base_u )
382+ { base_vt , key } = Nx.Random . uniform ( key , 0 , 1 , shape: { n , n } , type: :f64 )
383+ { vt , _ } = Nx.LinAlg . qr ( base_vt )
384+
385+ # because min(m, n) is always 3, we can use fixed values here
386+ # the important thing is that there's at least one zero in the
387+ # diagonal, to ensure that we're guarding against 0 division
388+ zeros = Nx . broadcast ( 0 , { m , n } )
389+ s = Nx . put_diagonal ( zeros , Nx . f64 ( [ 1 , 4 , 0 ] ) )
390+ s_inv = Nx . put_diagonal ( Nx . transpose ( zeros ) , Nx . f64 ( [ 1 , 0.25 , 0 ] ) )
391+
392+ # construct t with the given singular values
393+ t = u |> Nx . dot ( s ) |> Nx . dot ( vt )
394+ pinv = Nx.LinAlg . pinv ( t )
395+
396+ # ensure that the returned pinv is close to what we expect
397+ assert_all_close ( pinv , Nx . transpose ( vt ) |> Nx . dot ( s_inv ) |> Nx . dot ( Nx . transpose ( u ) ) ,
398+ atol: 1.0e-2
399+ )
400+
401+ key
402+ end
403+ end
404+ end
94405end
0 commit comments