-
Notifications
You must be signed in to change notification settings - Fork 213
Create blocked Jacobi method for eigen decomposition #1510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f003cc1
6568f92
7428fb8
4899d0a
e288bf5
b5bc4f9
289f964
ea02ed2
9bbb478
d0ea621
d5f454e
dcb52cb
318448f
941e606
eade22d
bd246f7
031d74e
fa9e4a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,254 @@ | ||||||
| defmodule Nx.LinAlg.BlockEigh do | ||||||
| @moduledoc """ | ||||||
| Parallel Jacobi symmetric eigendecomposition. | ||||||
|
|
||||||
| Reference implementation taking from XLA's eigh_expander | ||||||
| which is built on the approach in: | ||||||
| Brent, R. P., & Luk, F. T. (1985). The solution of singular-value | ||||||
| and symmetric eigenvalue problems on multiprocessor arrays. | ||||||
| SIAM Journal on Computing, 6(1), 69-84. https://doi.org/10.1137/0906007 | ||||||
| """ | ||||||
| require Nx | ||||||
|
|
||||||
| import Nx.Defn | ||||||
|
|
||||||
| defn calc_rot(tl, tr, br) do | ||||||
| a = Nx.take_diagonal(br) | ||||||
| b = Nx.take_diagonal(tr) | ||||||
| c = Nx.take_diagonal(tl) | ||||||
|
|
||||||
| tau = (a - c) / (2 * b) | ||||||
| t = Nx.sqrt(1 + Nx.pow(tau, 2)) | ||||||
| t = Nx.select(Nx.greater_equal(tau, 0), 1 / (tau + t), 1 / (tau - t)) | ||||||
christianjgreen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
|
|
||||||
| pred = Nx.less_equal(Nx.abs(b), 0.1 * 1.0e-4 * Nx.min(Nx.abs(a), Nx.abs(c))) | ||||||
|
||||||
| pred = Nx.less_equal(Nx.abs(b), 0.1 * 1.0e-4 * Nx.min(Nx.abs(a), Nx.abs(c))) | |
| pred = Nx.abs(b) <= 1.0e-5 * Nx.min(Nx.abs(a), Nx.abs(c)) |
christianjgreen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
christianjgreen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
christianjgreen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this should be the only defn in this module and the others would be defnp. Or something close to that.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why force f32 here? Is this another case where the algorithm just fails on f64?
Perhaps this should be masked underneath the implementation if it's the case.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use a pattern for organizing the while state that we do quite a lot:
{{tl, br, v_tl, v_tr, v_bl, v_br}, _} where you leave the outputs in a first-position tuple, and the other state in a second position, so pattern matching on the statement is easier, as well as understanding what's output and what's not
Uh oh!
There was an error while loading. Please reload this page.