@@ -116,150 +116,6 @@ defmodule Nx.BinaryBackend.Matrix do
116116
117117 defp do_ts ( [ ] , [ ] , _idx , acc ) , do: acc
118118
119- defp qr_decomposition ( matrix , n , _eps ) when n in 0 .. 1 do
120- { [ [ 1.0 ] ] , matrix }
121- end
122-
123- defp qr_decomposition ( matrix , n , eps ) when n >= 2 do
124- # QR decomposition is performed by using Householder transform
125- # this function originally supported generic QR, but
126- # it is now only used by eigh. Because of this,
127- # we simplified the function signature to only
128- # support square matrices.
129-
130- { q_matrix , r_matrix } =
131- for i <- 0 .. ( n - 2 ) // 1 , reduce: { nil , matrix } do
132- { q , r } ->
133- h =
134- r
135- |> slice_matrix ( [ i , i ] , [ n - i , 1 ] )
136- |> householder_reflector ( n , eps )
137-
138- # If we haven't allocated Q yet, let Q = H1
139- # TODO: Resolve inconsistent with the Householder reflector.
140- # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
141- q =
142- if is_nil ( q ) do
143- h
144- else
145- dot_matrix_real ( q , h )
146- end
147-
148- r = dot_matrix_real ( h , r )
149- { q , r }
150- end
151-
152- { approximate_zeros ( q_matrix , eps ) , approximate_zeros ( r_matrix , eps ) }
153- end
154-
155- defp raise_not_hermitian do
156- raise ArgumentError ,
157- "matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
158- end
159-
160- def eigh ( input_data , input_type , { n , n } = input_shape , output_type , opts ) do
161- eps = opts [ :eps ]
162- max_iter = opts [ :max_iter ]
163-
164- # Validate that the input is a Hermitian matrix using the relation A^* = A.
165- a = binary_to_matrix ( input_data , input_type , input_shape )
166-
167- is_hermitian =
168- a
169- |> transpose_matrix ( )
170- |> Enum . map ( fn a_row -> Enum . map ( a_row , & Complex . conjugate ( & 1 ) ) end )
171- |> is_approximately_same? ( a , eps )
172-
173- unless is_hermitian do
174- raise_not_hermitian ( )
175- end
176-
177- # Hessenberg decomposition
178- { h , q_h } = hessenberg_decomposition ( a , n , eps )
179-
180- # QR iteration for eigenvalues and eigenvectors
181- { eigenvals_diag , eigenvecs } =
182- Enum . reduce_while ( 1 .. max_iter // 1 , { h , q_h } , fn _ , { a_old , q_old } ->
183- # QR decomposition
184- { q_now , r_now } = qr_decomposition ( a_old , n , eps )
185-
186- # Update matrix A, Q
187- a_new = dot_matrix_real ( r_now , q_now )
188- q_new = dot_matrix_real ( q_old , q_now )
189-
190- if is_approximately_same? ( q_old , q_new , eps ) do
191- { :halt , { a_new , q_new } }
192- else
193- { :cont , { a_new , q_new } }
194- end
195- end )
196-
197- # Obtain the eigenvalues, which are the diagonal elements
198- indices_diag = for idx <- 0 .. ( n - 1 ) , do: [ idx , idx ]
199- eigenvals = get_matrix_elements ( eigenvals_diag , indices_diag )
200-
201- # In general, the eigenvalues of a Hermitian matrix are real numbers
202- eigenvals_real = eigenvals |> Enum . map ( & Complex . real ( & 1 ) )
203-
204- # Reduce the elements smaller than eps to zero
205- { eigenvals_real |> approximate_zeros ( eps ) |> matrix_to_binary ( output_type ) ,
206- eigenvecs |> approximate_zeros ( eps ) |> matrix_to_binary ( output_type ) }
207- end
208-
209- defp hessenberg_decomposition ( matrix , n , _eps ) when n in 0 .. 1 do
210- { matrix , [ [ 1.0 ] ] }
211- end
212-
213- defp hessenberg_decomposition ( matrix , n , eps ) do
214- # Hessenberg decomposition is performed by using Householder transform
215- { hess_matrix , q_matrix } =
216- for i <- 0 .. ( n - 2 ) // 1 , reduce: { matrix , nil } do
217- { hess , q } ->
218- h =
219- hess
220- |> slice_matrix ( [ i + 1 , i ] , [ n - i - 1 , 1 ] )
221- |> householder_reflector ( n , eps )
222-
223- # If we haven't allocated Q yet, let Q = H1
224- # TODO: Resolve inconsistent with the Householder reflector.
225- # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
226- q =
227- if is_nil ( q ) do
228- h
229- else
230- dot_matrix_real ( q , h )
231- end
232-
233- # Hessenberg matrix H updating
234- h_adj = adjoint_matrix ( h )
235-
236- hess =
237- h
238- |> dot_matrix_real ( hess )
239- |> dot_matrix_real ( h_adj )
240-
241- { hess , q }
242- end
243-
244- { approximate_zeros ( hess_matrix , eps ) , approximate_zeros ( q_matrix , eps ) }
245- end
246-
247- defp is_approximately_same? ( a , b , eps ) do
248- # Determine if matrices `a` and `b` are equal in the range of eps
249- a
250- |> Enum . zip ( b )
251- |> Enum . all? ( fn { a_row , b_row } ->
252- a_row
253- |> Enum . zip ( b_row )
254- |> Enum . all? ( fn
255- { a_elem , b_elem } ->
256- abs_diff = Complex . abs ( a_elem - b_elem )
257-
258- abs_diff == :nan or abs_diff <= eps
259- end )
260- end )
261- end
262-
263119 def lu ( input_data , input_type , { n , n } = input_shape , p_type , l_type , u_type , opts ) do
264120 a = binary_to_matrix ( input_data , input_type , input_shape )
265121 eps = opts [ :eps ]
@@ -361,116 +217,6 @@ defmodule Nx.BinaryBackend.Matrix do
361217 end )
362218 end
363219
364- ## Householder helpers
365-
366- defp householder_reflector ( a , target_k , eps )
367-
368- defp householder_reflector ( [ ] , target_k , _eps ) do
369- flat_list =
370- for col <- 0 .. ( target_k - 1 ) , row <- 0 .. ( target_k - 1 ) , into: [ ] do
371- if col == row , do: 1 , else: 0
372- end
373-
374- Enum . chunk_every ( flat_list , target_k )
375- end
376-
377- defp householder_reflector ( a , target_k , eps ) do
378- { v , scale , is_complex } = householder_reflector_pivot ( a , eps )
379-
380- prefix_threshold = target_k - length ( v )
381- v = List . duplicate ( 0 , prefix_threshold ) ++ v
382-
383- # dot(v, v) = norm_v_squared, which can be calculated from norm_a as:
384- # norm_v_squared = norm_a_squared - a_0^2 + v_0^2
385-
386- # execute I - 2 / norm_v_squared * outer(v, v)
387- { _ , _ , reflector_reversed } =
388- for col_factor <- v , row_factor <- v , reduce: { 0 , 0 , [ ] } do
389- { row , col , acc } ->
390- row_factor = if is_complex , do: Complex . conjugate ( row_factor ) , else: row_factor
391-
392- # The current element in outer(v, v) is given by col_factor * row_factor
393- # and the current I element is 1 when row == col
394- identity_element = if row == col , do: 1 , else: 0
395-
396- result =
397- if row >= prefix_threshold and col >= prefix_threshold do
398- identity_element -
399- scale * col_factor * row_factor
400- else
401- identity_element
402- end
403-
404- acc = [ result | acc ]
405-
406- if col + 1 == target_k do
407- { row + 1 , 0 , acc }
408- else
409- { row , col + 1 , acc }
410- end
411- end
412-
413- # This is equivalent to reflector_reversed |> Enum.reverse() |> Enum.chunk_every(target_k)
414- { reflector , _ , _ } =
415- for x <- reflector_reversed , reduce: { [ ] , [ ] , 0 } do
416- { result_acc , row_acc , col } ->
417- row_acc = [ x | row_acc ]
418-
419- if col + 1 == target_k do
420- { [ row_acc | result_acc ] , [ ] , 0 }
421- else
422- { result_acc , row_acc , col + 1 }
423- end
424- end
425-
426- reflector
427- end
428-
429- defp householder_reflector_pivot ( [ a_0 | tail ] = a , eps ) when is_number ( a_0 ) do
430- # This is a trick so we can both calculate the norm of a_reverse and extract the
431- # head a the same time we reverse the array
432- # receives a_reverse as a list of numbers and returns the reflector as a
433- # k x k matrix
434-
435- norm_a_squared = Enum . reduce ( a , 0 , fn x , acc -> x * Complex . conjugate ( x ) + acc end )
436- norm_a_sq_1on = norm_a_squared - a_0 * a_0
437-
438- if norm_a_sq_1on < eps do
439- { [ 1 | tail ] , 0 , false }
440- else
441- v_0 =
442- if a_0 <= 0 do
443- a_0 - Complex . sqrt ( norm_a_squared )
444- else
445- - norm_a_sq_1on / ( a_0 + Complex . sqrt ( norm_a_squared ) )
446- end
447-
448- v_0_sq = v_0 * v_0
449- scale = 2 * v_0_sq / ( norm_a_sq_1on + v_0_sq )
450- v = [ 1 | Enum . map ( tail , & ( & 1 / v_0 ) ) ]
451- { v , scale , false }
452- end
453- end
454-
455- defp householder_reflector_pivot ( [ a_0 | tail ] , _eps ) do
456- # complex case
457- norm_a_sq_1on = Enum . reduce ( tail , 0 , & ( Complex . abs_squared ( & 1 ) + & 2 ) )
458- norm_a_sq = norm_a_sq_1on + Complex . abs_squared ( a_0 )
459- norm_a = Complex . sqrt ( norm_a_sq )
460-
461- phase_a_0 = Complex . phase ( a_0 )
462- alfa = Complex . exp ( Complex . new ( 0 , phase_a_0 ) ) * norm_a
463-
464- # u = x - alfa * e1
465- u_0 = a_0 + alfa
466- u = [ u_0 | tail ]
467- norm_u_sq = norm_a_sq_1on + Complex . abs_squared ( u_0 )
468- norm_u = Complex . sqrt ( norm_u_sq )
469-
470- v = Enum . map ( u , & ( & 1 / norm_u ) )
471- { v , 2 , true }
472- end
473-
474220 ## Matrix (2-D array) manipulation
475221
476222 defp dot_matrix ( [ ] , _ ) , do: 0
@@ -491,24 +237,6 @@ defmodule Nx.BinaryBackend.Matrix do
491237 end )
492238 end
493239
494- defp dot_matrix_real ( m1 , m2 ) do
495- Enum . map ( m1 , fn row ->
496- m2
497- |> transpose_matrix ( )
498- |> Enum . map ( fn col ->
499- Enum . zip_reduce ( row , col , 0 , fn x , y , acc -> acc + x * y end )
500- end )
501- end )
502- end
503-
504- defp adjoint_matrix ( [ x | _ ] = m ) when not is_list ( x ) do
505- Enum . map ( m , & [ Complex . conjugate ( & 1 ) ] )
506- end
507-
508- defp adjoint_matrix ( m ) do
509- Enum . zip_with ( m , fn cols -> Enum . map ( cols , & Complex . conjugate / 1 ) end )
510- end
511-
512240 defp transpose_matrix ( [ x | _ ] = m ) when not is_list ( x ) do
513241 Enum . map ( m , & [ & 1 ] )
514242 end
0 commit comments