Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,12 @@ name = "kirin-toolchain"
version = "0.22.0-dev"
description = "The Kirin Toolchain for building compilers and interpreters."
authors = [{ name = "Roger-luo", email = "rluo@quera.com" }]
dependencies = [
"rich>=13.7.1",
"beartype>=0.17.2",
"typing_extensions>=4.11.0",
]
dependencies = ["rich>=13.7.1", "beartype>=0.17.2", "typing_extensions>=4.11.0"]
readme = "README.md"
requires-python = ">= 3.10"

[project.optional-dependencies]
vmath = [
"numpy>1.26.0",
"scipy>=1.15.3",
]
vmath = ["numpy>1.26.0", "scipy>=1.15.3"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand Down Expand Up @@ -59,7 +52,7 @@ exclude = [
"dist",
"node_modules",
"venv",
"example/quantum/script.py", # Ignore specific file
"example/quantum/script.py", # Ignore specific file
]

[tool.ruff.lint]
Expand All @@ -71,8 +64,19 @@ include = ["src"]
[tool.coverage.run]
include = ["src/kirin/*"]

[tool.uv]
dev-dependencies = [
[dependency-groups]
doc = [
"griffe-kirin>=0.1.0",
"griffe-inherited-docstrings>=1.1.1",
"mike>=2.1.3",
"mkdocs>=1.6.1",
"mkdocs-gen-files>=0.5.0",
"mkdocs-literate-nav>=0.6.1",
"mkdocs-material>=9.5.44",
"mkdocs-minify-plugin>=0.8.0",
"mkdocstrings[python]>=0.27.0",
]
dev = [
"black>=24.10.0",
"coverage>=7.6.4",
"ipykernel>=6.29.5",
Expand All @@ -87,16 +91,3 @@ dev-dependencies = [
"rust-just>=1.36.0",
"tomlkit>=0.13.2",
]

[dependency-groups]
doc = [
"griffe-kirin>=0.1.0",
"griffe-inherited-docstrings>=1.1.1",
"mike>=2.1.3",
"mkdocs>=1.6.1",
"mkdocs-gen-files>=0.5.0",
"mkdocs-literate-nav>=0.6.1",
"mkdocs-material>=9.5.44",
"mkdocs-minify-plugin>=0.8.0",
"mkdocstrings[python]>=0.27.0",
]
76 changes: 0 additions & 76 deletions src/kirin/dialects/func/attrs.py

This file was deleted.

49 changes: 35 additions & 14 deletions src/kirin/dialects/func/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
from kirin.decl import info, statement
from kirin.print.printer import Printer

from .attrs import Signature, MethodType
from ._dialect import dialect


class FuncOpCallableInterface(ir.CallableStmtInterface["Function"]):

@classmethod
def get_signature(cls, stmt: "Function") -> types.FunctionType:
params_types = [arg.type for arg in stmt.body.blocks[0].args[1:]]
return types.FunctionType(tuple(params_types), stmt.return_type)

@classmethod
def get_callable_region(cls, stmt: "Function") -> ir.Region:
return stmt.body
Expand Down Expand Up @@ -44,7 +48,6 @@ class Function(ir.Statement):
{
ir.IsolatedFromAbove(),
ir.SymbolOpInterface(),
ir.HasSignature(),
FuncOpCallableInterface(),
ir.HasCFG(),
ir.SSACFG(),
Expand All @@ -54,11 +57,11 @@ class Function(ir.Statement):
"""The symbol name of the function."""
slots: tuple[str, ...] = info.attribute(default=())
"""The argument names of the function."""
signature: Signature = info.attribute()
"""The signature of the function at declaration."""
return_type: types.TypeAttribute = info.attribute()
"""The return type of the function."""
body: ir.Region = info.region(multi=True)
"""The body of the function."""
result: ir.ResultValue = info.result(MethodType)
result: ir.ResultValue = info.result(types.MethodType)
"""The result of the function."""

def print_impl(self, printer: Printer) -> None:
Expand All @@ -76,8 +79,9 @@ def print_arg(pair: tuple[str, types.TypeAttribute]):
printer.plain_print(" : ")
printer.print(pair[1])

params_type = [arg.type for arg in self.body.blocks[0].args[1:]]
printer.print_seq(
zip(self.slots, self.signature.inputs),
zip(self.slots, params_type),
emit=print_arg,
prefix="(",
suffix=")",
Expand All @@ -86,7 +90,7 @@ def print_arg(pair: tuple[str, types.TypeAttribute]):

with printer.rich(style="comment"):
printer.plain_print(" -> ")
printer.print(self.signature.output)
printer.print(self.return_type)
printer.plain_print(" ")

printer.print(self.body)
Expand All @@ -101,7 +105,6 @@ class Lambda(ir.Statement):
traits = frozenset(
{
ir.Pure(),
ir.HasSignature(),
ir.SymbolOpInterface(),
FuncOpCallableInterface(),
ir.HasCFG(),
Expand All @@ -111,11 +114,11 @@ class Lambda(ir.Statement):
sym_name: str = info.attribute()
slots: tuple[str, ...] = info.attribute(default=())
"""The argument names of the function."""
signature: Signature = info.attribute()
"""The signature of the function at declaration."""
return_type: types.TypeAttribute = info.attribute()
"""The return type of the function."""
captured: tuple[ir.SSAValue, ...] = info.argument()
body: ir.Region = info.region(multi=True)
result: ir.ResultValue = info.result(MethodType)
result: ir.ResultValue = info.result(types.MethodType)

def check(self) -> None:
assert self.body.blocks, "lambda body must not be empty"
Expand All @@ -130,9 +133,27 @@ def print_impl(self, printer: Printer) -> None:

printer.print_seq(self.captured, prefix="(", suffix=")", delim=", ")

def print_arg(pair: tuple[str, types.TypeAttribute]):
with printer.rich(style="symbol"):
printer.plain_print(pair[0])
with printer.rich(style="black"):
printer.plain_print(" : ")
printer.print(pair[1])

with printer.rich(style="bright_black"):
printer.plain_print(" -> ")
printer.print(self.signature.output)
printer.plain_print(" -> (")
params_type = [arg.type for arg in self.body.blocks[0].args[1:]]
printer.print_seq(
zip(self.slots, params_type),
emit=print_arg,
prefix="(",
suffix=")",
delim=", ",
)
printer.plain_print(" ")
printer.plain_print("-> ")
printer.print(self.return_type)
printer.plain_print(")")

printer.plain_print(" ")
printer.print(self.body)
Expand All @@ -145,7 +166,7 @@ def print_impl(self, printer: Printer) -> None:
class GetField(ir.Statement):
name = "getfield"
traits = frozenset({ir.Pure()})
obj: ir.SSAValue = info.argument(MethodType)
obj: ir.SSAValue = info.argument(types.MethodType)
field: int = info.attribute()
# NOTE: mypy somehow doesn't understand default init=False
result: ir.ResultValue = info.result(init=False)
Expand Down
4 changes: 2 additions & 2 deletions src/kirin/dialects/func/typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@
):
body_frame, ret = interp_.call(
stmt,
types.MethodType,
types.TypeofMethodType,

Check failure on line 90 in src/kirin/dialects/func/typeinfer.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "type[TypeofMethodType]" cannot be assigned to parameter "args" of type "TypeAttribute" in function "call"   Type "type[TypeofMethodType]" is not assignable to type "TypeAttribute" (reportArgumentType)
*tuple(arg.type for arg in stmt.body.blocks[0].args[1:]),
)
argtypes = tuple(arg.type for arg in stmt.body.blocks[0].args[1:])
ret = types.MethodType[[*argtypes], ret]
ret = types.TypeofMethodType[[*argtypes], ret]

Check failure on line 94 in src/kirin/dialects/func/typeinfer.py

View workflow job for this annotation

GitHub Actions / pyright

Expected no type arguments for class "TypeofMethodType" (reportInvalidTypeArguments)
frame.entries.update(body_frame.entries) # pass results back to upper frame
self_ = stmt.body.blocks[0].args[0]
frame.set(self_, ret)

Check failure on line 97 in src/kirin/dialects/func/typeinfer.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "type[TypeofMethodType]" cannot be assigned to parameter "value" of type "TypeAttribute" in function "set"   Type "type[TypeofMethodType]" is not assignable to type "TypeAttribute" (reportArgumentType)
return (ret,)

@impl(GetField)
Expand Down
2 changes: 1 addition & 1 deletion src/kirin/dialects/ilist/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Push(ir.Statement):
class Map(ir.Statement):
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
purity: bool = info.attribute(default=False)
fn: ir.SSAValue = info.argument(types.MethodType[[ElemT], OutElemT])
fn: ir.SSAValue = info.argument(types.TypeofMethodType[[ElemT], OutElemT])
collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
result: ir.ResultValue = info.result(IListType[OutElemT, ListLen])

Expand Down
2 changes: 1 addition & 1 deletion src/kirin/dialects/py/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Constant(ir.Statement, Generic[T]):
def __init__(self, value: T | ir.Data[T]) -> None:
if isinstance(value, ir.Method):
value = ir.PyAttr(
value, pytype=types.MethodType[list(value.arg_types), value.return_type]
value, pytype=types.TypeofMethodType[list(value.arg_types), value.return_type]
)
elif not isinstance(value, ir.Data):
value = ir.PyAttr(value)
Expand Down
4 changes: 3 additions & 1 deletion src/kirin/ir/attrs/_types.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from .abc import Attribute
from .types import Union, Generic, Literal, PyClass, TypeVar, TypeAttribute
from .types import Union, Generic, Literal, PyClass, TypeVar, TypeofMethodType, FunctionType, TypeAttribute

@dataclass
class _TypeAttribute(Attribute):
Expand All @@ -10,4 +10,6 @@
def is_subseteq_TypeVar(self, other: TypeVar) -> bool: ...
def is_subseteq_PyClass(self, other: PyClass) -> bool: ...
def is_subseteq_Generic(self, other: Generic) -> bool: ...
def is_subseteq_TypeofMethodType(self, other: TypeofMethodType) -> bool: ...
def is_subseteq_MethodType(self, other: FunctionType) -> bool: ...
def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ...
Loading
Loading