|
3 | 3 | ABC, |
4 | 4 | abstractmethod, |
5 | 5 | ) |
| 6 | +from functools import ( |
| 7 | + wraps, |
| 8 | +) |
6 | 9 | from typing import ( |
7 | 10 | Any, |
| 11 | + Callable, |
8 | 12 | Optional, |
| 13 | + overload, |
9 | 14 | ) |
10 | 15 |
|
11 | 16 | import array_api_compat |
@@ -116,6 +121,105 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]: |
116 | 121 | return np.from_dlpack(x) |
117 | 122 |
|
118 | 123 |
|
| 124 | +def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]: |
| 125 | + """A decorator that casts and casts back the input |
| 126 | + and output tensor of a method. |
| 127 | +
|
| 128 | + The decorator should be used on an instance method. |
| 129 | +
|
| 130 | + The decorator will do the following thing: |
| 131 | + (1) It casts input arrays from the global precision |
| 132 | + to precision defined by property `precision`. |
| 133 | + (2) It casts output arrays from `precision` to |
| 134 | + the global precision. |
| 135 | + (3) It checks inputs and outputs and only casts when |
| 136 | + input or output is an array and its dtype matches |
| 137 | + the global precision and `precision`, respectively. |
| 138 | + If it does not match (e.g. it is an integer), the decorator |
| 139 | + will do nothing on it. |
| 140 | +
|
| 141 | + The decorator supports the array API. |
| 142 | +
|
| 143 | + Returns |
| 144 | + ------- |
| 145 | + Callable |
| 146 | + a decorator that casts and casts back the input and |
| 147 | + output array of a method |
| 148 | +
|
| 149 | + Examples |
| 150 | + -------- |
| 151 | + >>> class A: |
| 152 | + ... def __init__(self): |
| 153 | + ... self.precision = "float32" |
| 154 | + ... |
| 155 | + ... @cast_precision |
| 156 | + ... def f(x: Array, y: Array) -> Array: |
| 157 | + ... return x**2 + y |
| 158 | + """ |
| 159 | + |
| 160 | + @wraps(func) |
| 161 | + def wrapper(self, *args, **kwargs): |
| 162 | + # only convert tensors |
| 163 | + returned_tensor = func( |
| 164 | + self, |
| 165 | + *[safe_cast_array(vv, "global", self.precision) for vv in args], |
| 166 | + **{ |
| 167 | + kk: safe_cast_array(vv, "global", self.precision) |
| 168 | + for kk, vv in kwargs.items() |
| 169 | + }, |
| 170 | + ) |
| 171 | + if isinstance(returned_tensor, tuple): |
| 172 | + return tuple( |
| 173 | + safe_cast_array(vv, self.precision, "global") for vv in returned_tensor |
| 174 | + ) |
| 175 | + elif isinstance(returned_tensor, dict): |
| 176 | + return { |
| 177 | + kk: safe_cast_array(vv, self.precision, "global") |
| 178 | + for kk, vv in returned_tensor.items() |
| 179 | + } |
| 180 | + else: |
| 181 | + return safe_cast_array(returned_tensor, self.precision, "global") |
| 182 | + |
| 183 | + return wrapper |
| 184 | + |
| 185 | + |
| 186 | +@overload |
| 187 | +def safe_cast_array( |
| 188 | + input: np.ndarray, from_precision: str, to_precision: str |
| 189 | +) -> np.ndarray: ... |
| 190 | +@overload |
| 191 | +def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ... |
| 192 | +def safe_cast_array( |
| 193 | + input: Optional[np.ndarray], from_precision: str, to_precision: str |
| 194 | +) -> Optional[np.ndarray]: |
| 195 | + """Convert an array from a precision to another precision. |
| 196 | +
|
| 197 | + If input is not an array or without the specific precision, the method will not |
| 198 | + cast it. |
| 199 | +
|
| 200 | + Array API is supported. |
| 201 | +
|
| 202 | + Parameters |
| 203 | + ---------- |
| 204 | + input : np.ndarray or None |
| 205 | + Input array |
| 206 | + from_precision : str |
| 207 | + Array data type that is casted from |
| 208 | + to_precision : str |
| 209 | + Array data type that casts to |
| 210 | +
|
| 211 | + Returns |
| 212 | + ------- |
| 213 | + np.ndarray or None |
| 214 | + casted array |
| 215 | + """ |
| 216 | + if array_api_compat.is_array_api_obj(input): |
| 217 | + xp = array_api_compat.array_namespace(input) |
| 218 | + if input.dtype == get_xp_precision(xp, from_precision): |
| 219 | + return xp.astype(input, get_xp_precision(xp, to_precision)) |
| 220 | + return input |
| 221 | + |
| 222 | + |
119 | 223 | __all__ = [ |
120 | 224 | "GLOBAL_NP_FLOAT_PRECISION", |
121 | 225 | "GLOBAL_ENER_FLOAT_PRECISION", |
|
0 commit comments