Skip to content

Commit 8dc7b29

Browse files
Blocked Jacobi method for eigen decomposition (#1510)
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent 7ef59e2 commit 8dc7b29

File tree

8 files changed

+396
-377
lines changed

8 files changed

+396
-377
lines changed

nx/lib/nx/binary_backend.ex

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,25 +1240,6 @@ defmodule Nx.BinaryBackend do
12401240
output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
12411241
end
12421242

1243-
@impl true
1244-
def eigh(
1245-
{%{type: output_type} = eigenvals_holder, eigenvecs_holder},
1246-
%{type: input_type, shape: input_shape} = tensor,
1247-
opts
1248-
) do
1249-
bin = to_binary(tensor)
1250-
rank = tuple_size(input_shape)
1251-
n = elem(input_shape, rank - 1)
1252-
1253-
{eigenvals, eigenvecs} =
1254-
bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, fn matrix, {vals_acc, vecs_acc} ->
1255-
{vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts)
1256-
{vals_acc <> vals, vecs_acc <> vecs}
1257-
end)
1258-
1259-
{from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)}
1260-
end
1261-
12621243
@impl true
12631244
def lu(
12641245
{%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},

nx/lib/nx/binary_backend/matrix.ex

Lines changed: 0 additions & 272 deletions
Original file line numberDiff line numberDiff line change
@@ -116,150 +116,6 @@ defmodule Nx.BinaryBackend.Matrix do
116116

117117
defp do_ts([], [], _idx, acc), do: acc
118118

119-
defp qr_decomposition(matrix, n, _eps) when n in 0..1 do
120-
{[[1.0]], matrix}
121-
end
122-
123-
defp qr_decomposition(matrix, n, eps) when n >= 2 do
124-
# QR decomposition is performed by using Householder transform
125-
# this function originally supported generic QR, but
126-
# it is now only used by eigh. Because of this,
127-
# we simplified the function signature to only
128-
# support square matrices.
129-
130-
{q_matrix, r_matrix} =
131-
for i <- 0..(n - 2)//1, reduce: {nil, matrix} do
132-
{q, r} ->
133-
h =
134-
r
135-
|> slice_matrix([i, i], [n - i, 1])
136-
|> householder_reflector(n, eps)
137-
138-
# If we haven't allocated Q yet, let Q = H1
139-
# TODO: Resolve inconsistent with the Householder reflector.
140-
# cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
141-
q =
142-
if is_nil(q) do
143-
h
144-
else
145-
dot_matrix_real(q, h)
146-
end
147-
148-
r = dot_matrix_real(h, r)
149-
{q, r}
150-
end
151-
152-
{approximate_zeros(q_matrix, eps), approximate_zeros(r_matrix, eps)}
153-
end
154-
155-
defp raise_not_hermitian do
156-
raise ArgumentError,
157-
"matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
158-
end
159-
160-
def eigh(input_data, input_type, {n, n} = input_shape, output_type, opts) do
161-
eps = opts[:eps]
162-
max_iter = opts[:max_iter]
163-
164-
# Validate that the input is a Hermitian matrix using the relation A^* = A.
165-
a = binary_to_matrix(input_data, input_type, input_shape)
166-
167-
is_hermitian =
168-
a
169-
|> transpose_matrix()
170-
|> Enum.map(fn a_row -> Enum.map(a_row, &Complex.conjugate(&1)) end)
171-
|> is_approximately_same?(a, eps)
172-
173-
unless is_hermitian do
174-
raise_not_hermitian()
175-
end
176-
177-
# Hessenberg decomposition
178-
{h, q_h} = hessenberg_decomposition(a, n, eps)
179-
180-
# QR iteration for eigenvalues and eigenvectors
181-
{eigenvals_diag, eigenvecs} =
182-
Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _, {a_old, q_old} ->
183-
# QR decomposition
184-
{q_now, r_now} = qr_decomposition(a_old, n, eps)
185-
186-
# Update matrix A, Q
187-
a_new = dot_matrix_real(r_now, q_now)
188-
q_new = dot_matrix_real(q_old, q_now)
189-
190-
if is_approximately_same?(q_old, q_new, eps) do
191-
{:halt, {a_new, q_new}}
192-
else
193-
{:cont, {a_new, q_new}}
194-
end
195-
end)
196-
197-
# Obtain the eigenvalues, which are the diagonal elements
198-
indices_diag = for idx <- 0..(n - 1), do: [idx, idx]
199-
eigenvals = get_matrix_elements(eigenvals_diag, indices_diag)
200-
201-
# In general, the eigenvalues of a Hermitian matrix are real numbers
202-
eigenvals_real = eigenvals |> Enum.map(&Complex.real(&1))
203-
204-
# Reduce the elements smaller than eps to zero
205-
{eigenvals_real |> approximate_zeros(eps) |> matrix_to_binary(output_type),
206-
eigenvecs |> approximate_zeros(eps) |> matrix_to_binary(output_type)}
207-
end
208-
209-
defp hessenberg_decomposition(matrix, n, _eps) when n in 0..1 do
210-
{matrix, [[1.0]]}
211-
end
212-
213-
defp hessenberg_decomposition(matrix, n, eps) do
214-
# Hessenberg decomposition is performed by using Householder transform
215-
{hess_matrix, q_matrix} =
216-
for i <- 0..(n - 2)//1, reduce: {matrix, nil} do
217-
{hess, q} ->
218-
h =
219-
hess
220-
|> slice_matrix([i + 1, i], [n - i - 1, 1])
221-
|> householder_reflector(n, eps)
222-
223-
# If we haven't allocated Q yet, let Q = H1
224-
# TODO: Resolve inconsistent with the Householder reflector.
225-
# cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
226-
q =
227-
if is_nil(q) do
228-
h
229-
else
230-
dot_matrix_real(q, h)
231-
end
232-
233-
# Hessenberg matrix H updating
234-
h_adj = adjoint_matrix(h)
235-
236-
hess =
237-
h
238-
|> dot_matrix_real(hess)
239-
|> dot_matrix_real(h_adj)
240-
241-
{hess, q}
242-
end
243-
244-
{approximate_zeros(hess_matrix, eps), approximate_zeros(q_matrix, eps)}
245-
end
246-
247-
defp is_approximately_same?(a, b, eps) do
248-
# Determine if matrices `a` and `b` are equal in the range of eps
249-
a
250-
|> Enum.zip(b)
251-
|> Enum.all?(fn {a_row, b_row} ->
252-
a_row
253-
|> Enum.zip(b_row)
254-
|> Enum.all?(fn
255-
{a_elem, b_elem} ->
256-
abs_diff = Complex.abs(a_elem - b_elem)
257-
258-
abs_diff == :nan or abs_diff <= eps
259-
end)
260-
end)
261-
end
262-
263119
def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do
264120
a = binary_to_matrix(input_data, input_type, input_shape)
265121
eps = opts[:eps]
@@ -361,116 +217,6 @@ defmodule Nx.BinaryBackend.Matrix do
361217
end)
362218
end
363219

364-
## Householder helpers
365-
366-
defp householder_reflector(a, target_k, eps)
367-
368-
defp householder_reflector([], target_k, _eps) do
369-
flat_list =
370-
for col <- 0..(target_k - 1), row <- 0..(target_k - 1), into: [] do
371-
if col == row, do: 1, else: 0
372-
end
373-
374-
Enum.chunk_every(flat_list, target_k)
375-
end
376-
377-
defp householder_reflector(a, target_k, eps) do
378-
{v, scale, is_complex} = householder_reflector_pivot(a, eps)
379-
380-
prefix_threshold = target_k - length(v)
381-
v = List.duplicate(0, prefix_threshold) ++ v
382-
383-
# dot(v, v) = norm_v_squared, which can be calculated from norm_a as:
384-
# norm_v_squared = norm_a_squared - a_0^2 + v_0^2
385-
386-
# execute I - 2 / norm_v_squared * outer(v, v)
387-
{_, _, reflector_reversed} =
388-
for col_factor <- v, row_factor <- v, reduce: {0, 0, []} do
389-
{row, col, acc} ->
390-
row_factor = if is_complex, do: Complex.conjugate(row_factor), else: row_factor
391-
392-
# The current element in outer(v, v) is given by col_factor * row_factor
393-
# and the current I element is 1 when row == col
394-
identity_element = if row == col, do: 1, else: 0
395-
396-
result =
397-
if row >= prefix_threshold and col >= prefix_threshold do
398-
identity_element -
399-
scale * col_factor * row_factor
400-
else
401-
identity_element
402-
end
403-
404-
acc = [result | acc]
405-
406-
if col + 1 == target_k do
407-
{row + 1, 0, acc}
408-
else
409-
{row, col + 1, acc}
410-
end
411-
end
412-
413-
# This is equivalent to reflector_reversed |> Enum.reverse() |> Enum.chunk_every(target_k)
414-
{reflector, _, _} =
415-
for x <- reflector_reversed, reduce: {[], [], 0} do
416-
{result_acc, row_acc, col} ->
417-
row_acc = [x | row_acc]
418-
419-
if col + 1 == target_k do
420-
{[row_acc | result_acc], [], 0}
421-
else
422-
{result_acc, row_acc, col + 1}
423-
end
424-
end
425-
426-
reflector
427-
end
428-
429-
defp householder_reflector_pivot([a_0 | tail] = a, eps) when is_number(a_0) do
430-
# This is a trick so we can both calculate the norm of a_reverse and extract the
431-
# head a the same time we reverse the array
432-
# receives a_reverse as a list of numbers and returns the reflector as a
433-
# k x k matrix
434-
435-
norm_a_squared = Enum.reduce(a, 0, fn x, acc -> x * Complex.conjugate(x) + acc end)
436-
norm_a_sq_1on = norm_a_squared - a_0 * a_0
437-
438-
if norm_a_sq_1on < eps do
439-
{[1 | tail], 0, false}
440-
else
441-
v_0 =
442-
if a_0 <= 0 do
443-
a_0 - Complex.sqrt(norm_a_squared)
444-
else
445-
-norm_a_sq_1on / (a_0 + Complex.sqrt(norm_a_squared))
446-
end
447-
448-
v_0_sq = v_0 * v_0
449-
scale = 2 * v_0_sq / (norm_a_sq_1on + v_0_sq)
450-
v = [1 | Enum.map(tail, &(&1 / v_0))]
451-
{v, scale, false}
452-
end
453-
end
454-
455-
defp householder_reflector_pivot([a_0 | tail], _eps) do
456-
# complex case
457-
norm_a_sq_1on = Enum.reduce(tail, 0, &(Complex.abs_squared(&1) + &2))
458-
norm_a_sq = norm_a_sq_1on + Complex.abs_squared(a_0)
459-
norm_a = Complex.sqrt(norm_a_sq)
460-
461-
phase_a_0 = Complex.phase(a_0)
462-
alfa = Complex.exp(Complex.new(0, phase_a_0)) * norm_a
463-
464-
# u = x - alfa * e1
465-
u_0 = a_0 + alfa
466-
u = [u_0 | tail]
467-
norm_u_sq = norm_a_sq_1on + Complex.abs_squared(u_0)
468-
norm_u = Complex.sqrt(norm_u_sq)
469-
470-
v = Enum.map(u, &(&1 / norm_u))
471-
{v, 2, true}
472-
end
473-
474220
## Matrix (2-D array) manipulation
475221

476222
defp dot_matrix([], _), do: 0
@@ -491,24 +237,6 @@ defmodule Nx.BinaryBackend.Matrix do
491237
end)
492238
end
493239

494-
defp dot_matrix_real(m1, m2) do
495-
Enum.map(m1, fn row ->
496-
m2
497-
|> transpose_matrix()
498-
|> Enum.map(fn col ->
499-
Enum.zip_reduce(row, col, 0, fn x, y, acc -> acc + x * y end)
500-
end)
501-
end)
502-
end
503-
504-
defp adjoint_matrix([x | _] = m) when not is_list(x) do
505-
Enum.map(m, &[Complex.conjugate(&1)])
506-
end
507-
508-
defp adjoint_matrix(m) do
509-
Enum.zip_with(m, fn cols -> Enum.map(cols, &Complex.conjugate/1) end)
510-
end
511-
512240
defp transpose_matrix([x | _] = m) when not is_list(x) do
513241
Enum.map(m, &[&1])
514242
end

0 commit comments

Comments
 (0)