|
4 | 4 |
|
5 | 5 | # These functions are in both the main and linalg namespaces |
6 | 6 | from cubed.array_api.data_type_functions import result_type |
| 7 | +from cubed.array_api.dtypes import _floating_dtypes |
7 | 8 | from cubed.array_api.linear_algebra_functions import ( # noqa: F401 |
8 | 9 | matmul, |
9 | 10 | matrix_transpose, |
@@ -33,6 +34,9 @@ def qr(x, /, *, mode="reduced") -> QRResult: |
33 | 34 | if mode != "reduced": |
34 | 35 | raise ValueError("qr only supports mode='reduced'") |
35 | 36 |
|
| 37 | + if x.dtype not in _floating_dtypes: |
| 38 | + raise TypeError("Only floating-point dtypes are allowed in qr") |
| 39 | + |
36 | 40 | if x.numblocks[1] > 1: |
37 | 41 | raise ValueError( |
38 | 42 | "qr only supports tall-and-skinny (single column chunk) arrays. " |
@@ -80,7 +84,7 @@ def _qr_first_step(A): |
80 | 84 | nxp.linalg.qr, |
81 | 85 | A, |
82 | 86 | shapes=[A.shape, R1_shape], |
83 | | - dtypes=[nxp.float64, nxp.float64], |
| 87 | + dtypes=[A.dtype, A.dtype], |
84 | 88 | chunkss=[A.chunks, R1_chunks], |
85 | 89 | extra_projected_mem=extra_projected_mem, |
86 | 90 | ) |
@@ -119,7 +123,7 @@ def _qr_second_step(R1): |
119 | 123 | nxp.linalg.qr, |
120 | 124 | R1_single, |
121 | 125 | shapes=[Q2_shape, R2_shape], |
122 | | - dtypes=[nxp.float64, nxp.float64], |
| 126 | + dtypes=[R1.dtype, R1.dtype], |
123 | 127 | chunkss=[Q2_chunks, R2_chunks], |
124 | 128 | extra_projected_mem=extra_projected_mem, |
125 | 129 | ) |
@@ -148,7 +152,7 @@ def _qr_third_step(Q1, Q2): |
148 | 152 | Q1, |
149 | 153 | Q2, |
150 | 154 | shape=Q1_shape, |
151 | | - dtype=nxp.float64, |
| 155 | + dtype=result_type(Q1, Q2), |
152 | 156 | chunks=Q1_chunks, |
153 | 157 | extra_projected_mem=extra_projected_mem, |
154 | 158 | q1_chunks=Q1_chunks, |
|
0 commit comments