Skip to content

Commit 1f09855

Browse files
authored
[mypyc] Match int arguments to primitives with native int params (#20299)
Previously, an `int` argument value did not match a primitive whose parameter type was a native int such as `u8`. This forced generic operations to be used unless the code added an explicit `u8` conversion, cast, or type annotation. Now, if strict matching doesn't produce a result, we perform a second matching pass with relaxed subtyping rules that allow `int` to match any native int type. As an optimization, the second pass is only performed if any of the candidate primitives use native integer types (most don't).
1 parent 9250c1b commit 1f09855

File tree

4 files changed

+77
-23
lines changed

4 files changed

+77
-23
lines changed

mypyc/ir/ops.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ class to enable the new behavior. Sometimes adding a new abstract
3737
RStruct,
3838
RTuple,
3939
RType,
40+
RUnion,
4041
RVoid,
4142
bit_rprimitive,
4243
bool_rprimitive,
4344
cstring_rprimitive,
4445
float_rprimitive,
4546
int_rprimitive,
4647
is_bool_or_bit_rprimitive,
48+
is_fixed_width_rtype,
4749
is_int_rprimitive,
4850
is_none_rprimitive,
4951
is_pointer_rprimitive,
@@ -688,7 +690,7 @@ class PrimitiveDescription:
688690
Primitives get lowered into lower-level ops before code generation.
689691
690692
If c_function_name is provided, a primitive will be lowered into a CallC op.
691-
Otherwise custom logic will need to be implemented to transform the
693+
Otherwise, custom logic will need to be implemented to transform the
692694
primitive into lower-level ops.
693695
"""
694696

@@ -737,11 +739,24 @@ def __init__(
737739
# Capsule that needs to be imported and configured to call the primitive
738740
# (name of the target module, e.g. "librt.base64").
739741
self.capsule = capsule
742+
# Native integer types such as u8 can cause ambiguity in primitive
743+
# matching, since these are assignable to plain int *and* vice versa.
744+
# If this flag is set, the primitive has native integer types and must
745+
# be matched using more complex rules.
746+
self.is_ambiguous = any(has_fixed_width_int(t) for t in arg_types)
740747

741748
def __repr__(self) -> str:
742749
return f"<PrimitiveDescription {self.name!r}: {self.arg_types}>"
743750

744751

752+
def has_fixed_width_int(t: RType) -> bool:
    """Return True if t is a fixed-width type or contains one.

    Tuple and union types are searched recursively through their
    component types; any other type is checked directly with
    is_fixed_width_rtype.
    """
    if isinstance(t, RTuple):
        return any(map(has_fixed_width_int, t.types))
    if isinstance(t, RUnion):
        return any(map(has_fixed_width_int, t.items))
    return is_fixed_width_rtype(t)
758+
759+
745760
@final
746761
class PrimitiveOp(RegisterOp):
747762
"""A higher-level primitive operation.

mypyc/irbuild/ll_builder.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2209,8 +2209,14 @@ def matching_primitive_op(
22092209
args: list[Value],
22102210
line: int,
22112211
result_type: RType | None = None,
2212+
*,
22122213
can_borrow: bool = False,
2214+
strict: bool = True,
22132215
) -> Value | None:
2216+
"""Find primitive operation that is compatible with types of args.
2217+
2218+
Return None if none of them match.
2219+
"""
22142220
matching: PrimitiveDescription | None = None
22152221
for desc in candidates:
22162222
if len(desc.arg_types) != len(args):
@@ -2219,7 +2225,7 @@ def matching_primitive_op(
22192225
continue
22202226
if all(
22212227
# formal is not None and # TODO
2222-
is_subtype(actual.type, formal)
2228+
is_subtype(actual.type, formal, relaxed=not strict)
22232229
for actual, formal in zip(args, desc.arg_types)
22242230
) and (not desc.is_borrowed or can_borrow):
22252231
if matching:
@@ -2232,6 +2238,12 @@ def matching_primitive_op(
22322238
matching = desc
22332239
if matching:
22342240
return self.primitive_op(matching, args, line=line, result_type=result_type)
2241+
if strict and any(prim.is_ambiguous for prim in candidates):
2242+
# Also try a non-exact match if any primitives have ambiguous types.
2243+
return self.matching_primitive_op(
2244+
candidates, args, line, result_type, can_borrow=can_borrow, strict=False
2245+
)
2246+
22352247
return None
22362248

22372249
def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) -> Value:

mypyc/subtype.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,21 @@
2323
)
2424

2525

26-
def is_subtype(left: RType, right: RType) -> bool:
26+
def is_subtype(left: RType, right: RType, *, relaxed: bool = False) -> bool:
2727
if is_object_rprimitive(right):
2828
return True
2929
elif isinstance(right, RUnion):
3030
if isinstance(left, RUnion):
3131
for left_item in left.items:
32-
if not any(is_subtype(left_item, right_item) for right_item in right.items):
32+
if not any(
33+
is_subtype(left_item, right_item, relaxed=relaxed)
34+
for right_item in right.items
35+
):
3336
return False
3437
return True
3538
else:
36-
return any(is_subtype(left, item) for item in right.items)
37-
return left.accept(SubtypeVisitor(right))
39+
return any(is_subtype(left, item, relaxed=relaxed) for item in right.items)
40+
return left.accept(SubtypeVisitor(right, relaxed=relaxed))
3841

3942

4043
class SubtypeVisitor(RTypeVisitor[bool]):
@@ -44,14 +47,15 @@ class SubtypeVisitor(RTypeVisitor[bool]):
4447
is_subtype and don't need to be covered here.
4548
"""
4649

47-
def __init__(self, right: RType) -> None:
50+
def __init__(self, right: RType, *, relaxed: bool = False) -> None:
4851
self.right = right
52+
self.relaxed = relaxed
4953

5054
def visit_rinstance(self, left: RInstance) -> bool:
5155
return isinstance(self.right, RInstance) and self.right.class_ir in left.class_ir.mro
5256

5357
def visit_runion(self, left: RUnion) -> bool:
54-
return all(is_subtype(item, self.right) for item in left.items)
58+
return all(is_subtype(item, self.right, relaxed=self.relaxed) for item in left.items)
5559

5660
def visit_rprimitive(self, left: RPrimitive) -> bool:
5761
right = self.right
@@ -64,6 +68,11 @@ def visit_rprimitive(self, left: RPrimitive) -> bool:
6468
elif is_short_int_rprimitive(left):
6569
if is_int_rprimitive(right):
6670
return True
71+
if self.relaxed and is_fixed_width_rtype(right):
72+
return True
73+
elif is_int_rprimitive(left):
74+
if self.relaxed and is_fixed_width_rtype(right):
75+
return True
6776
elif is_fixed_width_rtype(left):
6877
if is_int_rprimitive(right):
6978
return True
@@ -74,7 +83,8 @@ def visit_rtuple(self, left: RTuple) -> bool:
7483
return True
7584
if isinstance(self.right, RTuple):
7685
return len(self.right.types) == len(left.types) and all(
77-
is_subtype(t1, t2) for t1, t2 in zip(left.types, self.right.types)
86+
is_subtype(t1, t2, relaxed=self.relaxed)
87+
for t1, t2 in zip(left.types, self.right.types)
7888
)
7989
return False
8090

mypyc/test-data/irbuild-librt-strings.test

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,55 @@
1-
[case testLibrtStrings_experimental]
1+
[case testLibrtStrings_experimental_64bit]
22
from librt.strings import BytesWriter
33
from mypy_extensions import u8, i64
44

55
def bytes_writer_basics() -> bytes:
66
b = BytesWriter()
7-
x: u8 = 1
8-
b.append(x)
7+
b.append(1)
98
b.write(b'foo')
10-
n: i64 = 2
9+
n = 2
1110
b.truncate(n)
1211
return b.getvalue()
1312
def bytes_writer_len(b: BytesWriter) -> i64:
1413
return len(b)
1514
[out]
1615
def bytes_writer_basics():
1716
r0, b :: librt.strings.BytesWriter
18-
x :: u8
1917
r1 :: None
2018
r2 :: bytes
2119
r3 :: None
22-
n :: i64
23-
r4 :: None
24-
r5 :: bytes
20+
n :: int
21+
r4 :: native_int
22+
r5 :: bit
23+
r6, r7 :: i64
24+
r8 :: ptr
25+
r9 :: c_ptr
26+
r10 :: i64
27+
r11 :: None
28+
r12 :: bytes
2529
L0:
2630
r0 = LibRTStrings_BytesWriter_internal()
2731
b = r0
28-
x = 1
29-
r1 = LibRTStrings_BytesWriter_append_internal(b, x)
32+
r1 = LibRTStrings_BytesWriter_append_internal(b, 1)
3033
r2 = b'foo'
3134
r3 = LibRTStrings_BytesWriter_write_internal(b, r2)
32-
n = 2
33-
r4 = LibRTStrings_BytesWriter_truncate_internal(b, n)
34-
r5 = LibRTStrings_BytesWriter_getvalue_internal(b)
35-
return r5
35+
n = 4
36+
r4 = n & 1
37+
r5 = r4 == 0
38+
if r5 goto L1 else goto L2 :: bool
39+
L1:
40+
r6 = n >> 1
41+
r7 = r6
42+
goto L3
43+
L2:
44+
r8 = n ^ 1
45+
r9 = r8
46+
r10 = CPyLong_AsInt64(r9)
47+
r7 = r10
48+
keep_alive n
49+
L3:
50+
r11 = LibRTStrings_BytesWriter_truncate_internal(b, r7)
51+
r12 = LibRTStrings_BytesWriter_getvalue_internal(b)
52+
return r12
3653
def bytes_writer_len(b):
3754
b :: librt.strings.BytesWriter
3855
r0 :: short_int

0 commit comments

Comments
 (0)