Skip to content

Commit aa1c7fc

Browse files
authored
adding equality and hash to grid (#32)
1 parent 87e53f7 commit aa1c7fc

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

src/bloqade/geometry/dialects/grid/types.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ def get_indices(size: int, index: Any) -> ilist.IList[int, Any]:
1515
if isinstance(index, slice):
1616
return ilist.IList(range(size)[index])
1717
elif isinstance(index, slice):
18-
slice_value = index
19-
return ilist.IList(range(size)[slice_value])
18+
return ilist.IList(range(size)[index])
2019
elif isinstance(index, int):
2120
if index < 0:
2221
index += size
@@ -134,7 +133,7 @@ def x_bounds(self):
134133
135134
"""
136135
if self.x_init is None:
137-
return (None, None)
136+
raise ValueError("x_init is None, cannot compute bounds")
138137

139138
return (self.x_init, self.x_init + self.width)
140139

@@ -146,7 +145,7 @@ def y_bounds(self):
146145
147146
"""
148147
if self.y_init is None:
149-
return (None, None)
148+
raise ValueError("y_init is None, cannot compute bounds")
150149

151150
return (self.y_init, self.y_init + self.height)
152151

@@ -298,7 +297,16 @@ def __getitem__(self, indices):
298297
return self.get_view(x_indices=x_indices, y_indices=y_indices)
299298

300299
def __hash__(self) -> int:
301-
return id(self)
300+
return hash((self.x_spacing, self.y_spacing, self.x_init, self.y_init))
301+
302+
def __eq__(self, other: Any) -> bool:
303+
return (
304+
isinstance(other, Grid)
305+
and self.x_spacing == other.x_spacing
306+
and self.y_spacing == other.y_spacing
307+
and self.x_init == other.x_init
308+
and self.y_init == other.y_init
309+
)
302310

303311
def print_impl(self, printer: Printer) -> None:
304312
printer.plain_print("Grid(")
@@ -453,7 +461,10 @@ def get_view(self, x_indices, y_indices):
453461
)
454462

455463
def __hash__(self):
456-
return id(self)
464+
return super().__hash__()
465+
466+
def __eq__(self, other: Any) -> bool:
467+
return super().__eq__(other)
457468

458469
def __repr__(self):
459470
return super().__repr__()

test/grid/test_types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ def test_empty_positions(self):
127127
grid_obj = Grid.from_positions([], [1])
128128
assert grid_obj.x_positions == ()
129129
assert grid_obj.y_positions == (1,)
130-
assert grid_obj.x_bounds() == (None, None)
130+
131+
with pytest.raises(ValueError):
132+
grid_obj.x_bounds()
133+
131134
assert grid_obj.y_bounds() == (1, 1)
132135
assert grid_obj.width == 0
133136
assert grid_obj.height == 0

0 commit comments

Comments
 (0)