1- #include " custom_calls.h"
2-
3- #include " Eigen/Dense"
4- #include " Eigen/Eigenvalues"
5- #include " Eigen/QR"
6- #include " exla_nif_util.h"
71#include " xla/service/custom_call_target_registry.h"
82
9- template <typename DataType>
10- void single_matrix_eigh_cpu_custom_call (DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) {
11- typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
12-
13- // Map the input matrix
14- Eigen::Map<RowMajorMatrix> input (in, m, n);
15-
16- // Compute the Eigenvalue decomposition
17- Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver (input);
18-
19- if (eigensolver.info () != Eigen::Success) {
20- std::cerr << " Eigenvalue decomposition failed!" << std::endl;
21- return ;
22- }
23-
24- // Get the eigenvalues and eigenvectors
25- Eigen::Matrix<DataType, Eigen::Dynamic, 1 > eigenvalues = eigensolver.eigenvalues ();
26- RowMajorMatrix eigenvectors = eigensolver.eigenvectors ();
27-
28- // Copy the eigenvalues to the output
29- std::memcpy (eigenvalues_out, eigenvalues.data (), m * sizeof (DataType));
30-
31- // Copy the eigenvectors to the output
32- std::memcpy (eigenvectors_out, eigenvectors.data (), m * n * sizeof (DataType));
33- }
34-
35- template <typename DataType>
36- void single_matrix_qr_cpu_custom_call (DataType *q_out, DataType *r_out, DataType *in, uint64_t m, uint64_t k, uint64_t n, bool complete) {
37- typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
38-
39- Eigen::Map<RowMajorMatrix> input (in, m, n);
40- Eigen::HouseholderQR<RowMajorMatrix> qr = input.householderQr ();
41-
42- RowMajorMatrix Q, R;
43- size_t num_bytes_q, num_bytes_r;
44-
45- if (complete) {
46- Q = qr.householderQ () * RowMajorMatrix::Identity (m, m);
47- R = qr.matrixQR ();
48-
49- num_bytes_q = m * m * sizeof (DataType);
50-
51- for (uint64_t i = 0 ; i < m; ++i) {
52- for (uint64_t j = 0 ; j < n; ++j) {
53- r_out[i * n + j] = (j >= i) ? R (i, j) : static_cast <DataType>(0.0 );
54- }
55- }
56- } else {
57- Q = qr.householderQ () * RowMajorMatrix::Identity (m, k);
58- R = qr.matrixQR ().topRows (k);
59-
60- num_bytes_q = m * k * sizeof (DataType);
61-
62- for (uint64_t i = 0 ; i < k; ++i) {
63- for (uint64_t j = 0 ; j < n; ++j) {
64- r_out[i * n + j] = (j >= i) ? R (i, j) : static_cast <DataType>(0.0 );
65- }
66- }
67- }
68-
69- memcpy (q_out, Q.data (), num_bytes_q);
70- }
71-
72- template <typename DataType>
73- void qr_cpu_custom_call (void *out[], const void *in[]) {
74- DataType *operand = (DataType *)in[0 ];
75-
76- uint64_t *dim_sizes = (uint64_t *)in[1 ];
77- uint64_t num_operand_dims = dim_sizes[0 ];
78- uint64_t num_q_dims = dim_sizes[1 ];
79- uint64_t num_r_dims = dim_sizes[2 ];
80-
81- uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
82- std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
83-
84- uint64_t *q_dims_ptr = (uint64_t *)in[3 ];
85- std::vector<uint64_t > q_dims (q_dims_ptr, q_dims_ptr + num_q_dims);
86-
87- uint64_t *r_dims_ptr = (uint64_t *)in[4 ];
88- std::vector<uint64_t > r_dims (r_dims_ptr, r_dims_ptr + num_r_dims);
89-
90- uint64_t m = q_dims[q_dims.size () - 2 ];
91- uint64_t k = q_dims[q_dims.size () - 1 ];
92- uint64_t n = r_dims[r_dims.size () - 1 ];
93- bool complete = r_dims[r_dims.size () - 2 ] == m;
94-
95- auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
96-
97- uint64_t batch_items = 1 ;
98- for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
99- batch_items *= leading_dimensions[i];
100- }
3+ void qr_cpu_custom_call_f32 (void *out[], const void *in[]);
4+ void qr_cpu_custom_call_f64 (void *out[], const void *in[]);
5+ void qr_cpu_custom_call_f16 (void *out[], const void *in[]);
6+ void qr_cpu_custom_call_bf16 (void *out[], const void *in[]);
7+ void eigh_cpu_custom_call_f32 (void *out[], const void *in[]);
8+ void eigh_cpu_custom_call_f64 (void *out[], const void *in[]);
1019
102- DataType *q = (DataType *)out[0 ];
103- DataType *r = (DataType *)out[1 ];
104-
105- uint64_t r_stride = r_dims[r_dims.size () - 1 ] * r_dims[r_dims.size () - 2 ] * sizeof (DataType);
106- uint64_t q_stride = q_dims[q_dims.size () - 1 ] * q_dims[q_dims.size () - 2 ] * sizeof (DataType);
107- uint64_t inner_stride = m * n * sizeof (DataType);
108-
109- for (uint64_t i = 0 ; i < batch_items; i++) {
110- single_matrix_qr_cpu_custom_call<DataType>(
111- (DataType *)out[0 ] + i * q_stride,
112- (DataType *)out[1 ] + i * r_stride,
113- operand + i * inner_stride * sizeof (DataType),
114- m, k, n, complete);
115- }
116- }
117-
118- template <typename DataType>
119- void eigh_cpu_custom_call (void *out[], const void *in[]) {
120- DataType *operand = (DataType *)in[0 ];
121-
122- uint64_t *dim_sizes = (uint64_t *)in[1 ];
123- uint64_t num_operand_dims = dim_sizes[0 ];
124- uint64_t num_eigenvalues_dims = dim_sizes[1 ];
125- uint64_t num_eigenvectors_dims = dim_sizes[2 ];
126-
127- uint64_t *operand_dims_ptr = (uint64_t *)in[2 ];
128- std::vector<uint64_t > operand_dims (operand_dims_ptr, operand_dims_ptr + num_operand_dims);
129-
130- uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3 ];
131- std::vector<uint64_t > eigenvalues_dims (eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
132-
133- uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4 ];
134- std::vector<uint64_t > eigenvectors_dims (eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
135-
136- uint64_t m = eigenvectors_dims[eigenvectors_dims.size () - 2 ];
137- uint64_t n = eigenvectors_dims[eigenvectors_dims.size () - 1 ];
138-
139- auto leading_dimensions = std::vector<uint64_t >(operand_dims.begin (), operand_dims.end () - 2 );
140-
141- uint64_t batch_items = 1 ;
142- for (uint64_t i = 0 ; i < leading_dimensions.size (); i++) {
143- batch_items *= leading_dimensions[i];
144- }
145-
146- DataType *eigenvalues = (DataType *)out[0 ];
147- DataType *eigenvectors = (DataType *)out[1 ];
148-
149- uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size () - 1 ] * sizeof (DataType);
150- uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size () - 1 ] * eigenvectors_dims[eigenvectors_dims.size () - 2 ] * sizeof (DataType);
151- uint64_t inner_stride = m * n * sizeof (DataType);
152-
153- for (uint64_t i = 0 ; i < batch_items; i++) {
154- single_matrix_eigh_cpu_custom_call<DataType>(
155- eigenvalues + i * eigenvalues_stride,
156- eigenvectors + i * eigenvectors_stride,
157- operand + i * inner_stride / sizeof (DataType),
158- m, n);
159- }
160- }
161-
162- void qr_cpu_custom_call_bf16 (void *out[], const void *in[]) {
163- qr_cpu_custom_call<exla::bfloat16>(out, in);
164- }
165-
166- void qr_cpu_custom_call_f16 (void *out[], const void *in[]) {
167- qr_cpu_custom_call<exla::float16>(out, in);
168- }
169-
170- void qr_cpu_custom_call_f32 (void *out[], const void *in[]) {
171- qr_cpu_custom_call<float >(out, in);
172- }
173-
174- void qr_cpu_custom_call_f64 (void *out[], const void *in[]) {
175- qr_cpu_custom_call<double >(out, in);
176- }
177-
178- void eigh_cpu_custom_call_f32 (void *out[], const void *in[]) {
179- eigh_cpu_custom_call<float >(out, in);
180- }
181-
182- void eigh_cpu_custom_call_f64 (void *out[], const void *in[]) {
183- eigh_cpu_custom_call<double >(out, in);
184- }
185-
186- XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f32" , qr_cpu_custom_call_f32);
18710XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f64" , qr_cpu_custom_call_f64);
11+ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f32" , qr_cpu_custom_call_f32);
18812XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_f16" , qr_cpu_custom_call_f16);
18913XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" qr_cpu_custom_call_bf16" , qr_cpu_custom_call_bf16);
190-
191-
192- XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f32" , eigh_cpu_custom_call_f32);
19314XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f64" , eigh_cpu_custom_call_f64);
15+ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM (" eigh_cpu_custom_call_f32" , eigh_cpu_custom_call_f32);
0 commit comments