Skip to content

Commit 4b6bce2

Browse files
committed
Adding typeinfer from grid.New statement. (#1)
* adding type infer for * making type check a bit tighter bound
1 parent ecf38fb commit 4b6bce2

File tree

7 files changed

+100
-3
lines changed

7 files changed

+100
-3
lines changed

src/bloqade/geometry/dialects/grid/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._dialect import dialect as dialect
2+
from ._typeinfer import TypeInferMethods as TypeInferMethods
23
from .concrete import GridInterpreter as GridInterpreter
34
from .stmts import (
45
FromPositions as FromPositions,
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import cast
2+
3+
from kirin import types
4+
from kirin.analysis import TypeInference
5+
from kirin.dialects import ilist
6+
from kirin.interp import Frame, MethodTable, impl
7+
8+
from ._dialect import dialect
9+
from .stmts import New
10+
from .types import GridType
11+
12+
13+
@dialect.register(key="typeinfer")
14+
class TypeInferMethods(MethodTable):
15+
16+
def get_len(self, typ: types.TypeAttribute):
17+
if typ.is_subseteq(ilist.IListType[types.Int, types.Any]):
18+
typ = cast(types.Generic, typ)
19+
if isinstance(typ.vars[1], types.Literal):
20+
return types.Literal(typ.vars[1].data + 1)
21+
22+
return types.Any
23+
24+
@impl(New)
25+
def inter_new(self, _: TypeInference, frame: Frame[types.TypeAttribute], node: New):
26+
x_len = self.get_len(frame.get(node.x_spacing))
27+
y_len = self.get_len(frame.get(node.y_spacing))
28+
29+
return (GridType[x_len, y_len],)

src/bloqade/geometry/dialects/grid/stmts.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,17 @@ class New(ir.Statement):
2626
name = "new"
2727
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
2828

29-
x_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any])
30-
y_spacing: ir.SSAValue = info.argument(type=ilist.IListType[types.Float, types.Any])
29+
x_spacing: ir.SSAValue = info.argument(
30+
type=ilist.IListType[types.Float, types.TypeVar("NumXStep")]
31+
)
32+
y_spacing: ir.SSAValue = info.argument(
33+
type=ilist.IListType[types.Float, types.TypeVar("NumYStep")]
34+
)
3135
x_init: ir.SSAValue = info.argument(types.Float)
3236
y_init: ir.SSAValue = info.argument(types.Float)
33-
result: ir.ResultValue = info.result(GridType[types.Any, types.Any])
37+
result: ir.ResultValue = info.result(
38+
GridType[types.TypeVar("NumX"), types.TypeVar("NumY")]
39+
)
3440

3541

3642
# Maybe do this with hints?

src/bloqade/geometry/prelude.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from kirin import ir
2+
from kirin.ir.method import Method
3+
from kirin.passes.default import Default
4+
from kirin.prelude import structural
5+
from typing_extensions import Annotated, Doc
6+
7+
from bloqade.geometry.dialects import grid
8+
9+
10+
@ir.dialect_group(structural.add(grid))
11+
def geometry(
12+
self,
13+
):
14+
"""Structural kernel with optimization passes."""
15+
16+
def run_pass(
17+
mt: Annotated[Method, Doc("The method to run pass on.")],
18+
*,
19+
verify: Annotated[
20+
bool, Doc("run `verify` before running passes, default is `True`")
21+
] = True,
22+
typeinfer: Annotated[
23+
bool,
24+
Doc(
25+
"run type inference and apply the inferred type to IR, default `False`"
26+
),
27+
] = False,
28+
fold: Annotated[bool, Doc("run folding passes")] = True,
29+
aggressive: Annotated[
30+
bool, Doc("run aggressive folding passes if `fold=True`")
31+
] = False,
32+
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
33+
) -> None:
34+
default_pass = Default(
35+
self,
36+
verify=verify,
37+
fold=fold,
38+
aggressive=aggressive,
39+
typeinfer=typeinfer,
40+
no_raise=no_raise,
41+
)
42+
default_pass.fixpoint(mt)
43+
44+
return run_pass

src/bloqade/geometry/rewrite/__init__.py

Whitespace-only changes.

src/bloqade/geometry/rewrite/desugar.py

Whitespace-only changes.

test/grid/test_typeinfer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from kirin import types
2+
3+
from bloqade.geometry.dialects import grid
4+
from bloqade.geometry.prelude import geometry
5+
6+
7+
def test_typeinfer():
8+
9+
@geometry
10+
def test_method():
11+
return grid.New([1, 2], [1, 2], 0, 0)
12+
13+
test_method.return_type.is_equal(grid.GridType[types.Literal(3), types.Literal(3)])
14+
15+
16+
if __name__ == "__main__":
17+
test_typeinfer()

0 commit comments

Comments
 (0)