Skip to content

Commit fc88358

Browse files
authored
backport __getitem__ on grid (#28)
* getitem * Adding getitem
1 parent 96d54e8 commit fc88358

File tree

1 file changed

+72
-1
lines changed
  • src/bloqade/geometry/dialects/grid

1 file changed

+72
-1
lines changed

src/bloqade/geometry/dialects/grid/types.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import dataclasses
22
from functools import cached_property
33
from itertools import chain, product
4-
from typing import Any, Generic, Sequence, TypeVar
4+
from typing import Any, Generic, Literal, Sequence, TypeVar, overload
55

66
from kirin import ir, types
77
from kirin.dialects import ilist
@@ -11,6 +11,26 @@
1111
NumY = TypeVar("NumY")
1212

1313

14+
def get_indices(size: int, index: Any) -> ilist.IList[int, Any]:
15+
if isinstance(index, slice):
16+
return ilist.IList(range(size)[index])
17+
elif isinstance(index, slice):
18+
slice_value = index
19+
return ilist.IList(range(size)[slice_value])
20+
elif isinstance(index, int):
21+
if index < 0:
22+
index += size
23+
24+
if index < 0 or index >= size:
25+
raise IndexError("Index out of range")
26+
27+
return ilist.IList([index])
28+
elif isinstance(index, ilist.IList):
29+
return index
30+
else:
31+
raise TypeError("Index must be an int, slice, or IList")
32+
33+
1434
@dataclasses.dataclass
1535
class Grid(ir.Data["Grid"], Generic[NumX, NumY]):
1636
x_spacing: tuple[float, ...]
@@ -198,6 +218,57 @@ def get_view(
198218
"""
199219
return SubGrid(parent=self, x_indices=x_indices, y_indices=y_indices)
200220

221+
@overload
222+
def __getitem__(
223+
self, indices: tuple[int, int]
224+
) -> "Grid[Literal[1], Literal[1]]": ...
225+
@overload
226+
def __getitem__(
227+
self, indices: tuple[int, slice | list[int]]
228+
) -> "Grid[Literal[1], Any]": ...
229+
230+
@overload
231+
def __getitem__(
232+
self, indices: tuple[int, ilist.IList[int, Ny]]
233+
) -> "Grid[Literal[1], Ny]": ...
234+
@overload
235+
def __getitem__(
236+
self, indices: tuple[slice | list[int], int]
237+
) -> "Grid[Any, Literal[1]]": ...
238+
@overload
239+
def __getitem__(
240+
self, indices: tuple[slice | list[int], slice]
241+
) -> "Grid[Any, Any]": ...
242+
243+
@overload
244+
def __getitem__(
245+
self, indices: tuple[slice | list[int], ilist.IList[int, Ny]]
246+
) -> "Grid[Any, Ny]": ...
247+
@overload
248+
def __getitem__(
249+
self, indices: tuple[ilist.IList[int, Nx], int]
250+
) -> "Grid[Nx, Literal[1]]": ...
251+
252+
@overload
253+
def __getitem__(
254+
self, indices: tuple[ilist.IList[int, Nx], slice | list[int]]
255+
) -> "Grid[Nx, Any]": ...
256+
257+
@overload
258+
def __getitem__(
259+
self, indices: tuple[ilist.IList[int, Nx], ilist.IList[int, Ny]]
260+
) -> "Grid[Nx, Ny]": ...
261+
262+
def __getitem__(self, indices):
263+
if len(indices) != 2:
264+
raise IndexError("Grid indexing requires two indices (x, y)")
265+
266+
x_index, y_index = indices
267+
x_indices = get_indices(len(self.x_spacing) + 1, x_index)
268+
y_indices = get_indices(len(self.y_spacing) + 1, y_index)
269+
270+
return self.get_view(x_indices=x_indices, y_indices=y_indices)
271+
201272
def __hash__(self) -> int:
202273
return id(self)
203274

0 commit comments

Comments
 (0)