|
1 | 1 | import dataclasses |
2 | 2 | from functools import cached_property |
3 | 3 | from itertools import chain, product |
4 | | -from typing import Any, Generic, Sequence, TypeVar |
| 4 | +from typing import Any, Generic, Literal, Sequence, TypeVar, overload |
5 | 5 |
|
6 | 6 | from kirin import ir, types |
7 | 7 | from kirin.dialects import ilist |
|
11 | 11 | NumY = TypeVar("NumY") |
12 | 12 |
|
13 | 13 |
|
| 14 | +def get_indices(size: int, index: Any) -> ilist.IList[int, Any]: |
| 15 | + if isinstance(index, slice): |
| 16 | + return ilist.IList(range(size)[index]) |
| 17 | + elif isinstance(index, slice): |
| 18 | + slice_value = index |
| 19 | + return ilist.IList(range(size)[slice_value]) |
| 20 | + elif isinstance(index, int): |
| 21 | + if index < 0: |
| 22 | + index += size |
| 23 | + |
| 24 | + if index < 0 or index >= size: |
| 25 | + raise IndexError("Index out of range") |
| 26 | + |
| 27 | + return ilist.IList([index]) |
| 28 | + elif isinstance(index, ilist.IList): |
| 29 | + return index |
| 30 | + else: |
| 31 | + raise TypeError("Index must be an int, slice, or IList") |
| 32 | + |
| 33 | + |
14 | 34 | @dataclasses.dataclass |
15 | 35 | class Grid(ir.Data["Grid"], Generic[NumX, NumY]): |
16 | 36 | x_spacing: tuple[float, ...] |
@@ -198,6 +218,57 @@ def get_view( |
198 | 218 | """ |
199 | 219 | return SubGrid(parent=self, x_indices=x_indices, y_indices=y_indices) |
200 | 220 |
|
| 221 | + @overload |
| 222 | + def __getitem__( |
| 223 | + self, indices: tuple[int, int] |
| 224 | + ) -> "Grid[Literal[1], Literal[1]]": ... |
| 225 | + @overload |
| 226 | + def __getitem__( |
| 227 | + self, indices: tuple[int, slice | list[int]] |
| 228 | + ) -> "Grid[Literal[1], Any]": ... |
| 229 | + |
| 230 | + @overload |
| 231 | + def __getitem__( |
| 232 | + self, indices: tuple[int, ilist.IList[int, Ny]] |
| 233 | + ) -> "Grid[Literal[1], Ny]": ... |
| 234 | + @overload |
| 235 | + def __getitem__( |
| 236 | + self, indices: tuple[slice | list[int], int] |
| 237 | + ) -> "Grid[Any, Literal[1]]": ... |
| 238 | + @overload |
| 239 | + def __getitem__( |
| 240 | + self, indices: tuple[slice | list[int], slice] |
| 241 | + ) -> "Grid[Any, Any]": ... |
| 242 | + |
| 243 | + @overload |
| 244 | + def __getitem__( |
| 245 | + self, indices: tuple[slice | list[int], ilist.IList[int, Ny]] |
| 246 | + ) -> "Grid[Any, Ny]": ... |
| 247 | + @overload |
| 248 | + def __getitem__( |
| 249 | + self, indices: tuple[ilist.IList[int, Nx], int] |
| 250 | + ) -> "Grid[Nx, Literal[1]]": ... |
| 251 | + |
| 252 | + @overload |
| 253 | + def __getitem__( |
| 254 | + self, indices: tuple[ilist.IList[int, Nx], slice | list[int]] |
| 255 | + ) -> "Grid[Nx, Any]": ... |
| 256 | + |
| 257 | + @overload |
| 258 | + def __getitem__( |
| 259 | + self, indices: tuple[ilist.IList[int, Nx], ilist.IList[int, Ny]] |
| 260 | + ) -> "Grid[Nx, Ny]": ... |
| 261 | + |
| 262 | + def __getitem__(self, indices): |
| 263 | + if len(indices) != 2: |
| 264 | + raise IndexError("Grid indexing requires two indices (x, y)") |
| 265 | + |
| 266 | + x_index, y_index = indices |
| 267 | + x_indices = get_indices(len(self.x_spacing) + 1, x_index) |
| 268 | + y_indices = get_indices(len(self.y_spacing) + 1, y_index) |
| 269 | + |
| 270 | + return self.get_view(x_indices=x_indices, y_indices=y_indices) |
| 271 | + |
201 | 272 | def __hash__(self) -> int: |
202 | 273 | return id(self) |
203 | 274 |
|
|
0 commit comments