diff --git a/docs/source/literal_types.rst b/docs/source/literal_types.rst index 877ab5de9087..8adbccddfb6c 100644 --- a/docs/source/literal_types.rst +++ b/docs/source/literal_types.rst @@ -324,7 +324,7 @@ perform an exhaustiveness check, you need to update your code to use an .. code-block:: python - from typing import Literal, NoReturn + from typing import Literal from typing_extensions import assert_never PossibleValues = Literal['one', 'two'] @@ -368,6 +368,19 @@ without a value: elif x == 'two': return False +For the sake of brevity, you can use the ``in`` operator in combination with +list, set, or tuple expressions (lists, sets, or tuples created "on the fly"): + +.. code-block:: python + + PossibleValues = Literal['one', 'two', 'three'] + + def validate(x: PossibleValues) -> bool: + if x in ['one']: + return True + elif x in ('two', 'three'): + return False + Exhaustiveness checking is also supported for match statements (Python 3.10 and later): .. code-block:: python diff --git a/mypy/checker.py b/mypy/checker.py index aceb0291926a..c0332e3f29c1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6,6 +6,7 @@ from collections import defaultdict from collections.abc import Iterable, Iterator, Mapping, Sequence, Set as AbstractSet from contextlib import ExitStack, contextmanager +from copy import copy from typing import Callable, Final, Generic, NamedTuple, Optional, TypeVar, Union, cast, overload from typing_extensions import TypeAlias as _TypeAlias, TypeGuard @@ -104,6 +105,7 @@ RaiseStmt, RefExpr, ReturnStmt, + SetExpr, StarExpr, Statement, StrExpr, @@ -4832,12 +4834,71 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if self.in_checked_function(): self.fail(message_registry.RETURN_VALUE_EXPECTED, s) + def _transform_sequence_expressions_for_narrowing_with_in(self, e: Expression) -> Expression: + """ + Transform an expression like + + (x is None) and (x in (1, 2)) and (x not in [3, 4]) + + into + + (x is None) and (x == 1 or x == 2) and (x != 3 and x != 4) + + This transformation is supposed to enable narrowing literals and enums using the + in (and the not in) operator in combination with tuple, list, and set expressions + without the need to implement additional narrowing logic. + """ + if isinstance(e, OpExpr): + e.left = self._transform_sequence_expressions_for_narrowing_with_in(e.left) + e.right = self._transform_sequence_expressions_for_narrowing_with_in(e.right) + return e + + if not ( + isinstance(e, ComparisonExpr) + and isinstance(left := e.operands[0], NameExpr) + and ((op_in := e.operators[0]) in ("in", "not in")) + and isinstance(litu := e.operands[1], (ListExpr, SetExpr, TupleExpr)) + ): + return e + + op_eq, op_con = (["=="], "or") if (op_in == "in") else (["!="], "and") + line = e.line + left_new = left + comparisons = [] + for right in reversed(litu.items): + if isinstance(right, StarExpr): + return e + comparison = ComparisonExpr(op_eq, [left_new, right]) + comparison.line = line + comparisons.append(comparison) + left_new = copy(left) + if (nmb := len(comparisons)) == 0: + if op_in == "in": + e = NameExpr("False") + e.fullname = "builtins.False" + e.line = line + return e + e = NameExpr("True") + e.fullname = "builtins.True" + e.line = line + return e + if nmb == 1: + return comparisons[0] + e = OpExpr(op_con, comparisons[1], comparisons[0]) + for comparison in comparisons[2:]: + e = OpExpr(op_con, comparison, e) + e.line = line + return e + def visit_if_stmt(self, s: IfStmt) -> None: """Type check an if statement.""" # This frame records the knowledge from previous if/elif clauses not being taken. # Fall-through to the original frame is handled explicitly in each block. with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0): for e, b in zip(s.expr, s.body): + + e = self._transform_sequence_expressions_for_narrowing_with_in(e) + t = get_proper_type(self.expr_checker.accept(e)) if isinstance(t, DeletedType): diff --git a/mypy/nodes.py b/mypy/nodes.py index 584e56667944..57da93c18ee3 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1910,7 +1910,9 @@ class NameExpr(RefExpr): __match_args__ = ("name", "node") - def __init__(self, name: str) -> None: + def __init__(self, name: str = "?") -> None: + # The default name "?" aims to make NameExpr mypyc copyable. + # Always pass a proper name when manually calling NameExpr.__init__. super().__init__() self.name = name # Name referred to # Is this a l.h.s. of a special form assignment like typed dict or type variable? diff --git a/mypyc/test-data/run-tuples.test b/mypyc/test-data/run-tuples.test index 1437eaef2aa5..0c6590ca91a8 100644 --- a/mypyc/test-data/run-tuples.test +++ b/mypyc/test-data/run-tuples.test @@ -278,3 +278,12 @@ def test_multiply() -> None: assert (1,) * 3 == res assert 3 * (1,) == res assert multiply((1,), 3) == res + +[case testTupleDoNotCrashOnTransformedInComparisons] +def f() -> None: + for n in ["x"]: + if n in ("x", "z") or n.startswith("y"): + print(n) +f() +[out] +x diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index 058db1ea8197..ebffdc5b543f 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -2004,7 +2004,7 @@ class C(A): pass y: Optional[B] if y in (B(), C()): - reveal_type(y) # N: Revealed type is "__main__.B" + reveal_type(y) # N: Revealed type is "Union[__main__.B, None]" else: reveal_type(y) # N: Revealed type is "Union[__main__.B, None]" [builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index dc2cfd46d9ad..74a55cee65f9 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1375,9 +1375,9 @@ else: if val in (None,): reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" else: - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "__main__.A" if val not in (None,): - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "__main__.A" else: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" [builtins fixtures/primitives.pyi] @@ -2313,6 +2313,107 @@ def f(x: C) -> None: f(C(5)) [builtins fixtures/primitives.pyi] +[case testNarrowLiteralsInListOrSetOrTupleExpression] +# flags: --warn-unreachable + +from typing import Optional +from typing_extensions import Literal + +x: int + +def f(v: Optional[Literal[1, 2, 3, 4]]) -> None: + if v in (0, 1, 2): + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]" + elif v in [1]: + reveal_type(v) # E: Statement is unreachable + elif v is None or v in {3, x}: + reveal_type(v) # N: Revealed type is "Union[Literal[3], Literal[4], None]" + elif v in (): + reveal_type(v) # E: Statement is unreachable + else: + reveal_type(v) # N: Revealed type is "Literal[4]" + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], None]" +[builtins fixtures/primitives.pyi] + +[case testNarrowLiteralsNotInListOrSetOrTupleExpression] +# flags: --warn-unreachable + +from typing import Optional +from typing_extensions import Literal + +x: int + +def f(v: Optional[Literal[1, 2, 3, 4, 5]]) -> None: + if v not in {0, 1, 2, 3}: + reveal_type(v) # N: Revealed type is "Union[Literal[4], Literal[5], None]" + elif v not in [1, 2, 3, 4]: # E: Right operand of "and" is never evaluated + reveal_type(v) # E: Statement is unreachable + elif v is not None and v not in (3,): + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2]]" + elif v not in (x, 3): + reveal_type(v) # E: Statement is unreachable + else: + reveal_type(v) # N: Revealed type is "Literal[3]" + reveal_type(v) # N: Revealed type is "Union[Literal[1], Literal[2], Literal[3], Literal[4], Literal[5], None]" +[builtins fixtures/primitives.pyi] + +[case testNarrowEnumsInListOrSetOrTupleExpression] +from enum import Enum +from typing import Final + +class E(Enum): + A = 1 + B = 2 + C = 3 + D = 4 + +A: Final = E.A +C: Final = E.C + +def f(v: E) -> None: + reveal_type(v) # N: Revealed type is "__main__.E" + if v in (A, E.B): + reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" + elif v in [E.A]: + reveal_type(v) + elif v in {C}: + reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" + elif v in (): + reveal_type(v) + else: + reveal_type(v) # N: Revealed type is "Literal[__main__.E.D]" + reveal_type(v) # N: Revealed type is "__main__.E" +[builtins fixtures/primitives.pyi] + +[case testNarrowEnumsNotInListOrSetOrTupleExpression] +from enum import Enum +from typing import Final + +class E(Enum): + A = 1 + B = 2 + C = 3 + D = 4 + E = 5 + +A: Final = E.A +C: Final = E.C + +def f(v: E) -> None: + reveal_type(v) # N: Revealed type is "__main__.E" + if v not in (A, E.B, E.C): + reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.D], Literal[__main__.E.E]]" + elif v not in [E.A, E.B, E.C, E.C]: + reveal_type(v) + elif v not in {C}: + reveal_type(v) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" + elif v not in []: + reveal_type(v) # N: Revealed type is "Literal[__main__.E.C]" + else: + reveal_type(v) + reveal_type(v) # N: Revealed type is "__main__.E" +[builtins fixtures/primitives.pyi] + [case testNarrowingTypeVarNone] # flags: --warn-unreachable diff --git a/test-data/unit/fixtures/tuple.pyi b/test-data/unit/fixtures/tuple.pyi index d01cd0034d26..5bc4b02da7ea 100644 --- a/test-data/unit/fixtures/tuple.pyi +++ b/test-data/unit/fixtures/tuple.pyi @@ -8,7 +8,8 @@ _Tco = TypeVar('_Tco', covariant=True) class object: def __init__(self) -> None: pass - def __new__(cls) -> Self: ... + def __new__(cls) -> Self: pass + def __eq__(self, other: object) -> bool: pass class type: def __init__(self, *a: object) -> None: pass