From e83679f0c2d541d40f2ebad3fe56bcb3501137d7 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 17 Mar 2024 18:06:44 +0100 Subject: [PATCH 01/18] Support narrowing literals and enums using the in operator in combination with tuple expressions. The general idea is to transform expressions like (x is None) and (x in (1, 2)) and (x not in (3, 4)) into (x is None) and (x == 1 or x == 2) and (x != 3 and x != 4) This transformation circumvents the need to extend the (already complicated) narrowing logic further. --- mypy/checker.py | 68 ++++++++++++++++++ test-data/unit/check-isinstance.test | 2 +- test-data/unit/check-narrowing.test | 101 +++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 941dc06f1c71..2f7389c9f681 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4532,12 +4532,80 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if self.in_checked_function(): self.fail(message_registry.RETURN_VALUE_EXPECTED, s) + def _make_tupleexpr_with_literals_narrowable_by_using_in( + self, e: Expression + ) -> Optional[Expression]: + """ + Transform an expression like + + (x is None) and (x in (1, 2)) and (x not in (3, 4)) + + into + + (x is None) and (x == 1 or x == 2) and (x != 3 and x != 4) + + This transformation is supposed to enable narrowing literals and enums using the in + (and the not in) operator in combination with tuple expressions without the need to + implement additional narrowing logic. + """ + if isinstance(e, OpExpr): + l = self._make_tupleexprs_with_literals_narrowable_by_using_in(e.left) + assert l is not None + e.left = l + r = self._make_tupleexprs_with_literals_narrowable_by_using_in(e.right) + assert r is not None + e.right = r + return e + + if not ( + isinstance(e, ComparisonExpr) and + isinstance(left := e.operands[0], NameExpr) and + ((op_in := e.operators[0]) in ("in", "not in")) and + isinstance(tuple_ := e.operands[1], TupleExpr) + ): + return e + + op_eq, op_con = (["=="], "or") if (op_in == "in") else (["!="], "and") + line = e.line + left_new = left + comparisons = [] + for right in reversed(tuple_.items): + comparison = ComparisonExpr(op_eq, [left_new, right]) + comparison.line = line + comparisons.append(comparison) + left_new = NameExpr(left.name) + left_new.node = left.node + left_new.line = left.line + if (nmb := len(comparisons)) == 0: + if op_in == "in": + if self.should_report_unreachable_issues(): + e.line += 1 + self.msg.unreachable_statement(e) + e.line -= 1 + return None + e = NameExpr("True") + e.fullname = "builtins.True" + e.line = line + return e + if nmb == 1: + return comparisons[0] + e = OpExpr(op_con, comparisons[1], comparisons[0]) + for comparison in comparisons[2:]: + e = OpExpr(op_con, comparison, e) + e.line = line + return e + def visit_if_stmt(self, s: IfStmt) -> None: """Type check an if statement.""" # This frame records the knowledge from previous if/elif clauses not being taken. # Fall-through to the original frame is handled explicitly in each block. with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0): for e, b in zip(s.expr, s.body): + + if (et := self._make_tupleexpr_with_literals_narrowable_by_using_in(e)) is None: + continue + e = et + t = get_proper_type(self.expr_checker.accept(e)) if isinstance(t, DeletedType): diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index b7ee38b69d00..254345190881 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -1969,7 +1969,7 @@ class C(A): pass y: Optional[B] if y in (B(), C()): - reveal_type(y) # N: Revealed type is "__main__.B" + reveal_type(y) # N: Revealed type is "Union[__main__.B, None]" else: reveal_type(y) # N: Revealed type is "Union[__main__.B, None]" [builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 4d117687554e..529602b6ab6c 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2089,3 +2089,104 @@ if isinstance(x, (Z, NoneType)): # E: Subclass of "X" and "Z" cannot exist: "Z" reveal_type(x) # E: Statement is unreachable [builtins fixtures/isinstance.pyi] + +[case testNarrowLiteralsInTupleExpression] +# flags: --warn-unreachable + +from typing import Optional +from typing_extensions import Literal + +x: int + +def f(v: Optional[Literal[1, 2, 3, 4]]) -> None: + if v in (0, 1, 2): + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]" + elif v in (1,): + reveal_type(v) # E: Statement is unreachable + elif v is None or v in (3, x): + reveal_type(v) # N: Revealed type is "Union[Literal[3], Literal[4], None]" + elif v in (): + reveal_type(v) # E: Statement is unreachable + else: + reveal_type(v) # N: Revealed type is "Literal[4]" + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], None]" +[builtins fixtures/primitives.pyi] + +[case testNarrowLiteralsNotInTupleExpression] +# flags: --warn-unreachable + +from typing import Optional +from typing_extensions import Literal + +x: int + +def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None: + if v not in (0, 1, 2, 3): + reveal_type(v) # N: Revealed type is "Union[Literal[4], Literal[5], None]" + elif v not in (1, 2, 3, 4): # E: Right operand of "and" is never evaluated + reveal_type(v) # E: Statement is unreachable + elif v is not None and v not in (3,): + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]" + elif v not in (x, 3): + reveal_type(v) # E: Statement is unreachable + else: + reveal_type(v) # N: Revealed type is "Literal[3]" + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], Literal[5], None]" +[builtins fixtures/primitives.pyi] + +[case testNarrowEnumsInTupleExpression] +from enum import Enum +from typing import Final + +class E(Enum): + A = 1 + B = 2 + C = 3 + D = 4 + +A: Final = E.A +C: Final = E.C + +def f(v: E) -> None: + reveal_type(v) # N: Revealed type is "__main__.E" + if v in (A, E.B): + reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" + elif v in (E.A,): + reveal_type(v) + elif v in (C,): + reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" + elif v in (): + reveal_type(v) + else: + reveal_type(v) # N: Revealed type is "Literal[__main__.E.D]" + reveal_type(v) # N: Revealed type is "__main__.E" +[builtins fixtures/primitives.pyi] + +[case testNarrowEnumsNotInTupleExpression] +from enum import Enum +from typing import Final + +class E(Enum): + A = 1 + B = 2 + C = 3 + D = 4 + E = 5 + +A: Final = E.A +C: Final = E.C + +def f(v: E) -> None: + reveal_type(v) # N: Revealed type is "__main__.E" + if v not in (A, E.B, E.C): + reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.D], Literal[__main__.E.E]]" + elif v not in (E.A, E.B, E.C, E.C): + reveal_type(v) + elif v not in (C,): + reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" + elif v not in (): + reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" + else: + reveal_type(v) + reveal_type(v) # N: Revealed type is "__main__.E" +[builtins fixtures/primitives.pyi] From a0d1db32069fe92b8118b58b198fbe662968c00a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 Mar 2024 17:16:07 +0000 Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 2f7389c9f681..2be8241a319e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4558,10 +4558,10 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in( return e if not ( - isinstance(e, ComparisonExpr) and - isinstance(left := e.operands[0], NameExpr) and - ((op_in := e.operators[0]) in ("in", "not in")) and - isinstance(tuple_ := e.operands[1], TupleExpr) + isinstance(e, ComparisonExpr) + and isinstance(left := e.operands[0], NameExpr) + and ((op_in := e.operators[0]) in ("in", "not in")) + and isinstance(tuple_ := e.operands[1], TupleExpr) ): return e From 2bff3c296c314984b30462166b285d8837e7ac47 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 17 Mar 2024 18:43:25 +0100 Subject: [PATCH 03/18] fix --- mypy/checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 2f7389c9f681..ee853f0542cd 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4549,10 +4549,10 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in( implement additional narrowing logic. """ if isinstance(e, OpExpr): - l = self._make_tupleexprs_with_literals_narrowable_by_using_in(e.left) + l = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.left) assert l is not None e.left = l - r = self._make_tupleexprs_with_literals_narrowable_by_using_in(e.right) + r = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.right) assert r is not None e.right = r return e From 2fa954a1f28f59b6d633eddc264690a68766d121 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 17 Mar 2024 19:24:29 +0100 Subject: [PATCH 04/18] Optional[Expression] -> Expression | None --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 120766a54933..87b5be0c4459 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4534,7 +4534,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None: def _make_tupleexpr_with_literals_narrowable_by_using_in( self, e: Expression - ) -> Optional[Expression]: + ) -> Expression | None: """ Transform an expression like From db7b96946dfdaa7f934fd1047a1a4cafcb022c08 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 17 Mar 2024 19:29:11 +0100 Subject: [PATCH 05/18] Add __eq__ to object in tuple.pyi --- test-data/unit/fixtures/tuple.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/test-data/unit/fixtures/tuple.pyi b/test-data/unit/fixtures/tuple.pyi index eb89de8c86ef..9e42fa0b4f44 100644 --- a/test-data/unit/fixtures/tuple.pyi +++ b/test-data/unit/fixtures/tuple.pyi @@ -8,6 +8,7 @@ Tco = TypeVar('Tco', covariant=True) class object: def __init__(self) -> None: pass + def __eq__(self, other: object) -> bool: pass class type: def __init__(self, *a: object) -> None: pass From 23bfd4e0981642500e6d6568f06ff4a9661c6328 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 18 Mar 2024 06:53:43 +0100 Subject: [PATCH 06/18] simplification: avoid returning None --- mypy/checker.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 87b5be0c4459..daac0dd85b74 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4532,9 +4532,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if self.in_checked_function(): self.fail(message_registry.RETURN_VALUE_EXPECTED, s) - def _make_tupleexpr_with_literals_narrowable_by_using_in( - self, e: Expression - ) -> Expression | None: + def _make_tupleexpr_with_literals_narrowable_by_using_in(self, e: Expression) -> Expression: """ Transform an expression like @@ -4549,12 +4547,8 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in( implement additional narrowing logic. """ if isinstance(e, OpExpr): - l = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.left) - assert l is not None - e.left = l - r = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.right) - assert r is not None - e.right = r + e.left = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.left) + e.right = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.right) return e if not ( @@ -4578,11 +4572,10 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in( left_new.line = left.line if (nmb := len(comparisons)) == 0: if op_in == "in": - if self.should_report_unreachable_issues(): - e.line += 1 - self.msg.unreachable_statement(e) - e.line -= 1 - return None + e = NameExpr("False") + e.fullname = "builtins.False" + e.line = line + return e e = NameExpr("True") e.fullname = "builtins.True" e.line = line @@ -4602,9 +4595,7 @@ def visit_if_stmt(self, s: IfStmt) -> None: with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0): for e, b in zip(s.expr, s.body): - if (et := self._make_tupleexpr_with_literals_narrowable_by_using_in(e)) is None: - continue - e = et + e = self._make_tupleexpr_with_literals_narrowable_by_using_in(e) t = get_proper_type(self.expr_checker.accept(e)) From 207c56eb3cd7732fbaa82e2f7e10736bf127a608 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Tue, 19 Mar 2024 23:05:48 +0100 Subject: [PATCH 07/18] replace the critical `in` comparisons for testing --- mypy/checkexpr.py | 2 +- mypy/find_sources.py | 2 +- mypy/reachability.py | 4 ++-- mypy/semanal.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 37a90ce55b9e..4fd8c3f04bdf 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3575,7 +3575,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # Only show dangerous overlap if there are no other errors. See # testCustomEqCheckStrictEquality for an example. - if not w.has_new_errors() and operator in ("==", "!="): + if not w.has_new_errors() and (operator == "==" or operator == "!="): right_type = self.accept(right) if self.dangerous_comparison(left_type, right_type): # Show the most specific literal types possible diff --git a/mypy/find_sources.py b/mypy/find_sources.py index 3565fc4609cd..1921ea2992d1 100644 --- a/mypy/find_sources.py +++ b/mypy/find_sources.py @@ -106,7 +106,7 @@ def find_sources_in_dir(self, path: str) -> list[BuildSource]: names = sorted(self.fscache.listdir(path), key=keyfunc) for name in names: # Skip certain names altogether - if name in ("__pycache__", "site-packages", "node_modules") or name.startswith("."): + if (name == "__pycache__" or name == "site-packages" or name == "node_modules") or name.startswith("."): continue subpath = os.path.join(path, name) diff --git a/mypy/reachability.py b/mypy/reachability.py index a25b9dff4581..81d673acf54a 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -130,8 +130,8 @@ def infer_condition_value(expr: Expression, options: Options) -> int: name = expr.name elif isinstance(expr, OpExpr) and expr.op in ("and", "or"): left = infer_condition_value(expr.left, options) - if (left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or ( - left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or" + if ((left == ALWAYS_TRUE or left == MYPY_TRUE) and expr.op == "and") or ( + (left == ALWAYS_FALSE or left == MYPY_FALSE) and expr.op == "or" ): # Either `True and ` or `False or `: the result will # always be the right-hand-side. diff --git a/mypy/semanal.py b/mypy/semanal.py index 6bf02382a036..c5e26b0ae358 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -5301,13 +5301,13 @@ def visit_op_expr(self, expr: OpExpr) -> None: if expr.op in ("and", "or"): inferred = infer_condition_value(expr.left, self.options) - if (inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "and") or ( - inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "or" + if ((inferred == ALWAYS_FALSE or inferred == MYPY_FALSE) and expr.op == "and") or ( + (inferred == ALWAYS_TRUE or inferred == MYPY_TRUE) and expr.op == "or" ): expr.right_unreachable = True return - elif (inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or ( - inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or" + elif ((inferred == ALWAYS_TRUE or inferred == MYPY_TRUE) and expr.op == "and") or ( + (inferred == ALWAYS_FALSE or inferred == MYPY_FALSE) and expr.op == "or" ): expr.right_always = True From dea2614f4f7023280884e0366a82d81a85e92481 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Mar 2024 22:07:30 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/find_sources.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy/find_sources.py b/mypy/find_sources.py index 1921ea2992d1..27ff3ca97106 100644 --- a/mypy/find_sources.py +++ b/mypy/find_sources.py @@ -106,7 +106,9 @@ def find_sources_in_dir(self, path: str) -> list[BuildSource]: names = sorted(self.fscache.listdir(path), key=keyfunc) for name in names: # Skip certain names altogether - if (name == "__pycache__" or name == "site-packages" or name == "node_modules") or name.startswith("."): + if ( + name == "__pycache__" or name == "site-packages" or name == "node_modules" + ) or name.startswith("."): continue subpath = os.path.join(path, name) From 214f51a79329415a5917805d467364d0641fc9ea Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Tue, 19 Mar 2024 23:26:05 +0100 Subject: [PATCH 09/18] Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit dea2614f4f7023280884e0366a82d81a85e92481. --- mypy/find_sources.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mypy/find_sources.py b/mypy/find_sources.py index 27ff3ca97106..1921ea2992d1 100644 --- a/mypy/find_sources.py +++ b/mypy/find_sources.py @@ -106,9 +106,7 @@ def find_sources_in_dir(self, path: str) -> list[BuildSource]: names = sorted(self.fscache.listdir(path), key=keyfunc) for name in names: # Skip certain names altogether - if ( - name == "__pycache__" or name == "site-packages" or name == "node_modules" - ) or name.startswith("."): + if (name == "__pycache__" or name == "site-packages" or name == "node_modules") or name.startswith("."): continue subpath = os.path.join(path, name) From a495fee4174b8be8c22765e5842e224fa8fcdfe3 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Tue, 19 Mar 2024 23:26:18 +0100 Subject: [PATCH 10/18] Revert "replace the critical `in` comparisons for testing" This reverts commit 207c56eb3cd7732fbaa82e2f7e10736bf127a608. --- mypy/checkexpr.py | 2 +- mypy/find_sources.py | 2 +- mypy/reachability.py | 4 ++-- mypy/semanal.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 4fd8c3f04bdf..37a90ce55b9e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3575,7 +3575,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # Only show dangerous overlap if there are no other errors. See # testCustomEqCheckStrictEquality for an example. - if not w.has_new_errors() and (operator == "==" or operator == "!="): + if not w.has_new_errors() and operator in ("==", "!="): right_type = self.accept(right) if self.dangerous_comparison(left_type, right_type): # Show the most specific literal types possible diff --git a/mypy/find_sources.py b/mypy/find_sources.py index 1921ea2992d1..3565fc4609cd 100644 --- a/mypy/find_sources.py +++ b/mypy/find_sources.py @@ -106,7 +106,7 @@ def find_sources_in_dir(self, path: str) -> list[BuildSource]: names = sorted(self.fscache.listdir(path), key=keyfunc) for name in names: # Skip certain names altogether - if (name == "__pycache__" or name == "site-packages" or name == "node_modules") or name.startswith("."): + if name in ("__pycache__", "site-packages", "node_modules") or name.startswith("."): continue subpath = os.path.join(path, name) diff --git a/mypy/reachability.py b/mypy/reachability.py index 81d673acf54a..a25b9dff4581 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -130,8 +130,8 @@ def infer_condition_value(expr: Expression, options: Options) -> int: name = expr.name elif isinstance(expr, OpExpr) and expr.op in ("and", "or"): left = infer_condition_value(expr.left, options) - if ((left == ALWAYS_TRUE or left == MYPY_TRUE) and expr.op == "and") or ( - (left == ALWAYS_FALSE or left == MYPY_FALSE) and expr.op == "or" + if (left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or ( + left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or" ): # Either `True and ` or `False or `: the result will # always be the right-hand-side. diff --git a/mypy/semanal.py b/mypy/semanal.py index c5e26b0ae358..6bf02382a036 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -5301,13 +5301,13 @@ def visit_op_expr(self, expr: OpExpr) -> None: if expr.op in ("and", "or"): inferred = infer_condition_value(expr.left, self.options) - if ((inferred == ALWAYS_FALSE or inferred == MYPY_FALSE) and expr.op == "and") or ( - (inferred == ALWAYS_TRUE or inferred == MYPY_TRUE) and expr.op == "or" + if (inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "and") or ( + inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "or" ): expr.right_unreachable = True return - elif ((inferred == ALWAYS_TRUE or inferred == MYPY_TRUE) and expr.op == "and") or ( - (inferred == ALWAYS_FALSE or inferred == MYPY_FALSE) and expr.op == "or" + elif (inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or ( + inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or" ): expr.right_always = True From 8785fe9b157d819d79ccfb15639127f18c7f201b Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Thu, 21 Mar 2024 06:54:14 +0100 Subject: [PATCH 11/18] fix mypyc crash --- mypy/checker.py | 5 ++--- mypyc/test-data/run-tuples.test | 9 +++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index daac0dd85b74..47ea92b10a33 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5,6 +5,7 @@ import itertools from collections import defaultdict from contextlib import ExitStack, contextmanager +from copy import copy from typing import ( AbstractSet, Callable, @@ -4567,9 +4568,7 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in(self, e: Expression) -> comparison = ComparisonExpr(op_eq, [left_new, right]) comparison.line = line comparisons.append(comparison) - left_new = NameExpr(left.name) - left_new.node = left.node - left_new.line = left.line + left_new = copy(left) if (nmb := len(comparisons)) == 0: if op_in == "in": e = NameExpr("False") diff --git a/mypyc/test-data/run-tuples.test b/mypyc/test-data/run-tuples.test index 0851c15e57fd..357081cdf4c5 100644 --- a/mypyc/test-data/run-tuples.test +++ b/mypyc/test-data/run-tuples.test @@ -256,3 +256,12 @@ TUPLE: Final[Tuple[str, ...]] = ('x', 'y') def test_final_boxed_tuple() -> None: t = TUPLE assert t == ('x', 'y') + +[case testTupleDoNotCrashOnTransformedInComparisons] +def f() -> None: + for n in ["x"]: + if n in ("x", "z") or n.startswith("y"): + print(n) +f() +[out] +x From 6af40aecf2ef1202227362de0cc9b122402d4ca9 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Thu, 21 Mar 2024 09:05:01 +0100 Subject: [PATCH 12/18] make NameExpr mypyc copyable --- mypy/nodes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index bb278d92392d..8de5221bc1c6 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1799,7 +1799,9 @@ class NameExpr(RefExpr): __match_args__ = ("name", "node") - def __init__(self, name: str) -> None: + def __init__(self, name: str = "?") -> None: + # The default name "?" aims to make NameExpr mypyc copyable. + # Always pass a proper name when manually calling NameExpr.__init__. super().__init__() self.name = name # Name referred to # Is this a l.h.s. of a special form assignment like typed dict or type variable? From e52eb2567d02b183cbeee96afb5dd02acebb80a8 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Thu, 21 Mar 2024 10:24:49 +0100 Subject: [PATCH 13/18] ignore star expressions --- mypy/checker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 47ea92b10a33..a6c6439e626b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4565,6 +4565,8 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in(self, e: Expression) -> left_new = left comparisons = [] for right in reversed(tuple_.items): + if isinstance(right, StarExpr): + return e comparison = ComparisonExpr(op_eq, [left_new, right]) comparison.line = line comparisons.append(comparison) From 763c2658959f4b7ccb0c5bb8b76ebc8f6bc037aa Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Thu, 21 Mar 2024 19:50:29 +0100 Subject: [PATCH 14/18] update docs --- docs/source/literal_types.rst | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/source/literal_types.rst b/docs/source/literal_types.rst index 283bf7f9dba1..e3a180037472 100644 --- a/docs/source/literal_types.rst +++ b/docs/source/literal_types.rst @@ -328,7 +328,7 @@ perform an exhaustiveness check, you need to update your code to use an .. code-block:: python - from typing import Literal, NoReturn + from typing import Literal from typing_extensions import assert_never PossibleValues = Literal['one', 'two'] @@ -372,6 +372,19 @@ without a value: elif x == 'two': return False +For the sake of brevity, you can use the ``in`` operator in combination with tuple expressions +(tuples created "on the fly"): + +.. code-block:: python + + PossibleValues = Literal['one', 'two', 'three'] + + def validate(x: PossibleValues) -> bool: + if x == 'one': + return True + elif x in ('two', 'three'): + return False + Exhaustiveness checking is also supported for match statements (Python 3.10 and later): .. code-block:: python From 44d71ebb77a28027c75c1a0cde85ecc915834b08 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sat, 26 Oct 2024 06:52:10 +0200 Subject: [PATCH 15/18] Also support list expressions. --- docs/source/literal_types.rst | 6 +++--- mypy/checker.py | 4 ++-- test-data/unit/check-narrowing.test | 20 ++++++++++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/source/literal_types.rst b/docs/source/literal_types.rst index e3a180037472..15c0f5e8c8e5 100644 --- a/docs/source/literal_types.rst +++ b/docs/source/literal_types.rst @@ -372,15 +372,15 @@ without a value: elif x == 'two': return False -For the sake of brevity, you can use the ``in`` operator in combination with tuple expressions -(tuples created "on the fly"): +For the sake of brevity, you can use the ``in`` operator in combination with list or tuple +expressions (lists or tuples created "on the fly"): .. code-block:: python PossibleValues = Literal['one', 'two', 'three'] def validate(x: PossibleValues) -> bool: - if x == 'one': + if x in ['one']: return True elif x in ('two', 'three'): return False diff --git a/mypy/checker.py b/mypy/checker.py index a6c6439e626b..756579d8803e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4556,7 +4556,7 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in(self, e: Expression) -> isinstance(e, ComparisonExpr) and isinstance(left := e.operands[0], NameExpr) and ((op_in := e.operators[0]) in ("in", "not in")) - and isinstance(tuple_ := e.operands[1], TupleExpr) + and isinstance(litu := e.operands[1], (ListExpr, TupleExpr)) ): return e @@ -4564,7 +4564,7 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in(self, e: Expression) -> line = e.line left_new = left comparisons = [] - for right in reversed(tuple_.items): + for right in reversed(litu.items): if isinstance(right, StarExpr): return e comparison = ComparisonExpr(op_eq, [left_new, right]) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 529602b6ab6c..517e9a8b08ec 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2090,7 +2090,7 @@ if isinstance(x, (Z, NoneType)): # E: Subclass of "X" and "Z" cannot exist: "Z" [builtins fixtures/isinstance.pyi] -[case testNarrowLiteralsInTupleExpression] +[case testNarrowLiteralsInListOrTupleExpression] # flags: --warn-unreachable from typing import Optional @@ -2101,9 +2101,9 @@ x: int def f(v: Optional[Literal[1, 2, 3, 4]]) -> None: if v in (0, 1, 2): reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]" - elif v in (1,): + elif v in [1]: reveal_type(v) # E: Statement is unreachable - elif v is None or v in (3, x): + elif v is None or v in [3, x]: reveal_type(v) # N: Revealed type is "Union[Literal[3], Literal[4], None]" elif v in (): reveal_type(v) # E: Statement is unreachable @@ -2112,7 +2112,7 @@ def f(v: Optional[Literal[1, 2, 3, 4]]) -> None: reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], None]" [builtins fixtures/primitives.pyi] -[case testNarrowLiteralsNotInTupleExpression] +[case testNarrowLiteralsNotInListOrTupleExpression] # flags: --warn-unreachable from typing import Optional @@ -2123,7 +2123,7 @@ x: int def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None: if v not in (0, 1, 2, 3): reveal_type(v) # N: Revealed type is "Union[Literal[4], Literal[5], None]" - elif v not in (1, 2, 3, 4): # E: Right operand of "and" is never evaluated + elif v not in [1, 2, 3, 4]: # E: Right operand of "and" is never evaluated reveal_type(v) # E: Statement is unreachable elif v is not None and v not in (3,): reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]" @@ -2134,7 +2134,7 @@ def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None: reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], Literal[5], None]" [builtins fixtures/primitives.pyi] -[case testNarrowEnumsInTupleExpression] +[case testNarrowEnumsInListOrTupleExpression] from enum import Enum from typing import Final @@ -2151,7 +2151,7 @@ def f(v: E) -> None: reveal_type(v) # N: Revealed type is "__main__.E" if v in (A, E.B): reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" - elif v in (E.A,): + elif v in [E.A]: reveal_type(v) elif v in (C,): reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" @@ -2162,7 +2162,7 @@ def f(v: E) -> None: reveal_type(v) # N: Revealed type is "__main__.E" [builtins fixtures/primitives.pyi] -[case testNarrowEnumsNotInTupleExpression] +[case testNarrowEnumsNotInListOrTupleExpression] from enum import Enum from typing import Final @@ -2180,11 +2180,11 @@ def f(v: E) -> None: reveal_type(v) # N: Revealed type is "__main__.E" if v not in (A, E.B, E.C): reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.D], Literal[__main__.E.E]]" - elif v not in (E.A, E.B, E.C, E.C): + elif v not in [E.A, E.B, E.C, E.C]: reveal_type(v) elif v not in (C,): reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" - elif v not in (): + elif v not in []: reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" else: reveal_type(v) From a04126d821a74b12ab5578991a170f67c8cf8500 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 27 Oct 2024 22:04:55 +0100 Subject: [PATCH 16/18] Also support set expressions. --- docs/source/literal_types.rst | 4 ++-- mypy/checker.py | 13 +++++++------ test-data/unit/check-narrowing.test | 16 ++++++++-------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/source/literal_types.rst b/docs/source/literal_types.rst index a47e5db8678d..8adbccddfb6c 100644 --- a/docs/source/literal_types.rst +++ b/docs/source/literal_types.rst @@ -368,8 +368,8 @@ without a value: elif x == 'two': return False -For the sake of brevity, you can use the ``in`` operator in combination with list or tuple -expressions (lists or tuples created "on the fly"): +For the sake of brevity, you can use the ``in`` operator in combination with +list, set, or tuple expressions (lists, sets, or tuples created "on the fly"): .. code-block:: python diff --git a/mypy/checker.py b/mypy/checker.py index b594e113d94b..c794e8e0c47b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -117,6 +117,7 @@ RaiseStmt, RefExpr, ReturnStmt, + SetExpr, StarExpr, Statement, StrExpr, @@ -4685,11 +4686,11 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if self.in_checked_function(): self.fail(message_registry.RETURN_VALUE_EXPECTED, s) - def _make_tupleexpr_with_literals_narrowable_by_using_in(self, e: Expression) -> Expression: + def _transform_sequence_expressions_for_narrowing_with_in(self, e: Expression) -> Expression: """ Transform an expression like - (x is None) and (x in (1, 2)) and (x not in (3, 4)) + (x is None) and (x in (1, 2)) and (x not in [3, 4]) into @@ -4700,15 +4701,15 @@ def _make_tupleexpr_with_literals_narrowable_by_using_in(self, e: Expression) -> implement additional narrowing logic. """ if isinstance(e, OpExpr): - e.left = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.left) - e.right = self._make_tupleexpr_with_literals_narrowable_by_using_in(e.right) + e.left = self._transform_sequence_expressions_for_narrowing_with_in(e.left) + e.right = self._transform_sequence_expressions_for_narrowing_with_in(e.right) return e if not ( isinstance(e, ComparisonExpr) and isinstance(left := e.operands[0], NameExpr) and ((op_in := e.operators[0]) in ("in", "not in")) - and isinstance(litu := e.operands[1], (ListExpr, TupleExpr)) + and isinstance(litu := e.operands[1], (ListExpr, SetExpr, TupleExpr)) ): return e @@ -4748,7 +4749,7 @@ def visit_if_stmt(self, s: IfStmt) -> None: with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0): for e, b in zip(s.expr, s.body): - e = self._make_tupleexpr_with_literals_narrowable_by_using_in(e) + e = self._transform_sequence_expressions_for_narrowing_with_in(e) t = get_proper_type(self.expr_checker.accept(e)) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index f3d5925aa216..9089d410f739 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2334,7 +2334,7 @@ def f(x: C) -> None: f(C(5)) [builtins fixtures/primitives.pyi] -[case testNarrowLiteralsInListOrTupleExpression] +[case testNarrowLiteralsInListOrSetOrTupleExpression] # flags: --warn-unreachable from typing import Optional @@ -2347,7 +2347,7 @@ def f(v: Optional[Literal[1, 2, 3, 4]]) -> None: reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]" elif v in [1]: reveal_type(v) # E: Statement is unreachable - elif v is None or v in [3, x]: + elif v is None or v in {3, x}: reveal_type(v) # N: Revealed type is "Union[Literal[3], Literal[4], None]" elif v in (): reveal_type(v) # E: Statement is unreachable @@ -2356,7 +2356,7 @@ def f(v: Optional[Literal[1, 2, 3, 4]]) -> None: reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], None]" [builtins fixtures/primitives.pyi] -[case testNarrowLiteralsNotInListOrTupleExpression] +[case testNarrowLiteralsNotInListOrSetOrTupleExpression] # flags: --warn-unreachable from typing import Optional @@ -2365,7 +2365,7 @@ from typing_extensions import Literal x: int def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None: - if v not in (0, 1, 2, 3): + if v not in {0, 1, 2, 3}: reveal_type(v) # N: Revealed type is "Union[Literal[4], Literal[5], None]" elif v not in [1, 2, 3, 4]: # E: Right operand of "and" is never evaluated reveal_type(v) # E: Statement is unreachable @@ -2378,7 +2378,7 @@ def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None: reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], Literal[5], None]" [builtins fixtures/primitives.pyi] -[case testNarrowEnumsInListOrTupleExpression] +[case testNarrowEnumsInListOrSetOrTupleExpression] from enum import Enum from typing import Final @@ -2397,7 +2397,7 @@ def f(v: E) -> None: reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" elif v in [E.A]: reveal_type(v) - elif v in (C,): + elif v in {C}: reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" elif v in (): reveal_type(v) @@ -2406,7 +2406,7 @@ def f(v: E) -> None: reveal_type(v) # N: Revealed type is "__main__.E" [builtins fixtures/primitives.pyi] -[case testNarrowEnumsNotInListOrTupleExpression] +[case testNarrowEnumsNotInListOrSetOrTupleExpression] from enum import Enum from typing import Final @@ -2426,7 +2426,7 @@ def f(v: E) -> None: reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.D], Literal[__main__.E.E]]" elif v not in [E.A, E.B, E.C, E.C]: reveal_type(v) - elif v not in (C,): + elif v not in {C}: reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" elif v not in []: reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" From 78dbff220e6a2b154110fc882a8abe4f57bf96fb Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Fri, 23 May 2025 21:41:05 +0200 Subject: [PATCH 17/18] Adjust the `testNarrowingOptionalEqualsNone` test case. A recent change that I just merged seems to interact positively with this PR: increase in consistency when narrowing away `None`. --- test-data/unit/check-narrowing.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 64673a9d57cb..74a55cee65f9 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1375,9 +1375,9 @@ else: if val in (None,): reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" else: - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "__main__.A" if val not in (None,): - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "__main__.A" else: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" [builtins fixtures/primitives.pyi] From e4456516d2353fefd2b44c2e45bf4bd15fb2afe3 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Fri, 23 May 2025 21:48:50 +0200 Subject: [PATCH 18/18] update comment --- mypy/checker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 3af19cb03e8b..c0332e3f29c1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4844,9 +4844,9 @@ def _transform_sequence_expressions_for_narrowing_with_in(self, e: Expression) - (x is None) and (x == 1 or x == 2) and (x != 3 and x != 4) - This transformation is supposed to enable narrowing literals and enums using the in - (and the not in) operator in combination with tuple expressions without the need to - implement additional narrowing logic. + This transformation is supposed to enable narrowing literals and enums using the + in (and the not in) operator in combination with tuple, list, and set expressions + without the need to implement additional narrowing logic. """ if isinstance(e, OpExpr): e.left = self._transform_sequence_expressions_for_narrowing_with_in(e.left)