diff --git a/mypy/checker.py b/mypy/checker.py index 07f5c520de95..28a10c9d4f9c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools import itertools from collections import defaultdict from collections.abc import Iterable, Iterator, Mapping, Sequence, Set as AbstractSet @@ -47,7 +48,12 @@ from mypy.expandtype import expand_type from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash from mypy.maptype import map_instance_to_supertype -from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types +from mypy.meet import ( + is_overlapping_erased_types, + is_overlapping_types, + meet_types, + narrow_declared_type, +) from mypy.message_registry import ErrorMessage from mypy.messages import ( SUGGESTED_TEST_FIXTURES, @@ -6237,65 +6243,89 @@ def is_type_call(expr: CallExpr) -> bool: # exprs that are being passed into type exprs_in_type_calls: list[Expression] = [] - # type that is being compared to type(expr) - type_being_compared: list[TypeRange] | None = None - # whether the type being compared to is final + # all the types that an expression will have if the overall expression is truthy + target_types: list[list[TypeRange]] = [] + # only a single type can be used when passed directly (eg "str") + fixed_type: Type | None = None + # is this single type final? is_final = False + def update_fixed_type(new_fixed_type: Type, new_is_final: bool) -> bool: + """Returns if the update succeeds""" + nonlocal fixed_type, is_final + if update := (fixed_type is None or (is_same_type(new_fixed_type, fixed_type))): + fixed_type = new_fixed_type + is_final = new_is_final + return update + for index in expr_indices: expr = node.operands[index] + proper_type = get_proper_type(self.lookup_type(expr)) if isinstance(expr, CallExpr) and is_type_call(expr): - exprs_in_type_calls.append(expr.args[0]) - else: - current_type = self.get_isinstance_type(expr) - if current_type is None: - continue - if type_being_compared is not None: - # It doesn't really make sense to have several types being - # compared to the output of type (like type(x) == int == str) - # because whether that's true is solely dependent on what the - # types being compared are, so we don't try to narrow types any - # further because we can't really get any information about the - # type of x from that check - return {}, {} - else: - if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo): - is_final = expr.node.is_final - type_being_compared = current_type + arg = expr.args[0] + exprs_in_type_calls.append(arg) + elif ( + isinstance(expr, OpExpr) + or isinstance(proper_type, TupleType) + or is_named_instance(proper_type, "builtins.tuple") + ): + # not valid for type comparisons, but allowed for isinstance checks + fixed_type = UninhabitedType() + continue + + type_range = self.get_isinstance_type(expr) + if type_range is not None: + target_types.append(type_range) + if ( + isinstance(expr, RefExpr) + and isinstance(expr.node, TypeInfo) + and len(type_range) == 1 + ): + if not update_fixed_type( + Instance( + expr.node, + [AnyType(TypeOfAny.special_form)] * len(expr.node.defn.type_vars), + ), + expr.node.is_final, + ): + return None, {} if not exprs_in_type_calls: return {}, {} - if_maps: list[TypeMap] = [] - else_maps: list[TypeMap] = [] + if_maps = [] + else_maps = [] for expr in exprs_in_type_calls: - current_if_type, current_else_type = self.conditional_types_with_intersection( - self.lookup_type(expr), type_being_compared, expr - ) - current_if_map, current_else_map = conditional_types_to_typemaps( - expr, current_if_type, current_else_type - ) - if_maps.append(current_if_map) - else_maps.append(current_else_map) + expr_type: Type = get_proper_type(self.lookup_type(expr)) + for type_range in target_types: + restriction, _ = self.conditional_types_with_intersection( + expr_type, type_range, expr + ) + if restriction is not None: + narrowed_type = get_proper_type(narrow_declared_type(expr_type, restriction)) + # Cannot be guaranteed that this is unreachable, so use fallback type. + if isinstance(narrowed_type, UninhabitedType): + expr_type = restriction + else: + expr_type = narrowed_type + _, else_map = conditional_types_to_typemaps( + expr, + *self.conditional_types_with_intersection( + (self.lookup_type(expr)), (type_range), expr + ), + ) + else_maps.append(else_map) - def combine_maps(list_maps: list[TypeMap]) -> TypeMap: - """Combine all typemaps in list_maps into one typemap""" - if all(m is None for m in list_maps): - return None - result_map = {} - for d in list_maps: - if d is not None: - result_map.update(d) - return result_map - - if_map = combine_maps(if_maps) - # type(x) == T is only true when x has the same type as T, meaning - # that it can be false if x is an instance of a subclass of T. That means - # we can't do any narrowing in the else case unless T is final, in which - # case T can't be subclassed + if fixed_type and expr_type is not None: + expr_type = narrow_declared_type(expr_type, fixed_type) + + if_map, _ = conditional_types_to_typemaps(expr, expr_type, None) + if_maps.append(if_map) + + if_map = functools.reduce(and_conditional_maps, if_maps) if is_final: - else_map = combine_maps(else_maps) + else_map = functools.reduce(or_conditional_maps, else_maps) else: else_map = {} return if_map, else_map @@ -7039,7 +7069,6 @@ def refine_away_none_in_comparison( if_map, else_map = {}, {} if not non_optional_types or (len(non_optional_types) != len(chain_indices)): - # Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be # convenient but is strictly not type-safe): for i in narrowable_operand_indices: @@ -7961,35 +7990,41 @@ def get_isinstance_type(self, expr: Expression) -> list[TypeRange] | None: return None return left + right all_types = get_proper_types(flatten_types(self.lookup_type(expr))) - types: list[TypeRange] = [] + type_ranges: list[TypeRange] = [] for typ in all_types: - if isinstance(typ, FunctionLike) and typ.is_type_obj(): - # If a type is generic, `isinstance` can only narrow its variables to Any. - any_parameterized = fill_typevars_with_any(typ.type_object()) - # Tuples may have unattended type variables among their items - if isinstance(any_parameterized, TupleType): - erased_type = erase_typevars(any_parameterized) - else: - erased_type = any_parameterized - types.append(TypeRange(erased_type, is_upper_bound=False)) - elif isinstance(typ, TypeType): - # Type[A] means "any type that is a subtype of A" rather than "precisely type A" - # we indicate this by setting is_upper_bound flag - is_upper_bound = True - if isinstance(typ.item, NoneType): - # except for Type[None], because "'NoneType' is not an acceptable base type" - is_upper_bound = False - types.append(TypeRange(typ.item, is_upper_bound=is_upper_bound)) - elif isinstance(typ, Instance) and typ.type.fullname == "builtins.type": - object_type = Instance(typ.type.mro[-1], []) - types.append(TypeRange(object_type, is_upper_bound=True)) - elif isinstance(typ, Instance) and typ.type.fullname == "types.UnionType" and typ.args: - types.append(TypeRange(UnionType(typ.args), is_upper_bound=False)) - elif isinstance(typ, AnyType): - types.append(TypeRange(typ, is_upper_bound=False)) - else: # we didn't see an actual type, but rather a variable with unknown value + type_range = self.isinstance_type_range(typ) + if type_range is None: return None - return types + type_ranges.append(type_range) + return type_ranges + + def isinstance_type_range(self, typ: ProperType) -> TypeRange | None: + if isinstance(typ, FunctionLike) and typ.is_type_obj(): + # If a type is generic, `isinstance` can only narrow its variables to Any. + any_parameterized = fill_typevars_with_any(typ.type_object()) + # Tuples may have unattended type variables among their items + if isinstance(any_parameterized, TupleType): + erased_type = erase_typevars(any_parameterized) + else: + erased_type = any_parameterized + return TypeRange(erased_type, is_upper_bound=False) + elif isinstance(typ, TypeType): + # Type[A] means "any type that is a subtype of A" rather than "precisely type A" + # we indicate this by setting is_upper_bound flag + is_upper_bound = True + if isinstance(typ.item, NoneType): + # except for Type[None], because "'NoneType' is not an acceptable base type" + is_upper_bound = False + return TypeRange(typ.item, is_upper_bound=is_upper_bound) + elif isinstance(typ, Instance) and typ.type.fullname == "builtins.type": + object_type = Instance(typ.type.mro[-1], []) + return TypeRange(object_type, is_upper_bound=True) + elif isinstance(typ, Instance) and typ.type.fullname == "types.UnionType" and typ.args: + return TypeRange(UnionType(typ.args), is_upper_bound=False) + elif isinstance(typ, AnyType): + return TypeRange(typ, is_upper_bound=False) + else: # we didn't see an actual type, but rather a variable with unknown value + return None def is_literal_enum(self, n: Expression) -> bool: """Returns true if this expression (with the given type context) is an Enum literal. diff --git a/mypy/subtypes.py b/mypy/subtypes.py index c02ff068560b..0cb9cbe29fc7 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,6 +8,7 @@ import mypy.applytype import mypy.constraints import mypy.typeops +from mypy.checker_shared import TypeRange from mypy.checker_state import checker_state from mypy.erasetype import erase_type from mypy.expandtype import ( @@ -255,6 +256,28 @@ def is_equivalent( ) +def is_same_type_ranges( + a: list[TypeRange], + b: list[TypeRange], + ignore_promotions: bool = True, + subtype_context: SubtypeContext | None = None, +) -> bool: + return len(a) == len(b) and all( + is_same_type_range(x, y, ignore_promotions, subtype_context) for x, y in zip(a, b) + ) + + +def is_same_type_range( + a: TypeRange, + b: TypeRange, + ignore_promotions: bool = True, + subtype_context: SubtypeContext | None = None, +) -> bool: + return a.is_upper_bound == b.is_upper_bound and is_same_type( + a.item, b.item, ignore_promotions, subtype_context + ) + + def is_same_type( a: Type, b: Type, ignore_promotions: bool = True, subtype_context: SubtypeContext | None = None ) -> bool: diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index acd4b588f98c..bb2fe38aa381 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -2716,6 +2716,10 @@ if type(x) == type(y) == int: reveal_type(y) # N: Revealed type is "builtins.int" reveal_type(x) # N: Revealed type is "builtins.int" +z: Any +if int == type(z) == int: + reveal_type(z) # N: Revealed type is "builtins.int" + [case testTypeEqualsCheckUsingIs] from typing import Any @@ -2723,16 +2727,70 @@ y: Any if type(y) is int: reveal_type(y) # N: Revealed type is "builtins.int" +[case testTypeEqualsCheckUsingImplicitTypes] +from typing import Any + +x: str +y: Any +z: object +if type(y) is type(x): + reveal_type(x) # N: Revealed type is "builtins.str" + reveal_type(y) # N: Revealed type is "builtins.str" + +if type(x) is type(z): + reveal_type(x) # N: Revealed type is "builtins.str" + reveal_type(z) # N: Revealed type is "builtins.str" + +[case testTypeEqualsCheckUsingDifferentSpecializedTypes] +from collections import defaultdict + +x: defaultdict +y: dict +z: object +if type(x) is type(y) is type(z): + reveal_type(x) # N: Revealed type is "collections.defaultdict[Any, Any]" + reveal_type(y) # N: Revealed type is "collections.defaultdict[Any, Any]" + reveal_type(z) # N: Revealed type is "collections.defaultdict[Any, Any]" + +[case testUnionTypeEquality] +from typing import Any, reveal_type +# flags: --warn-unreachable + +x: Any = () +if type(x) == (int, str): + reveal_type(x) # E: Statement is unreachable + +[builtins fixtures/tuple.pyi] + +[case testTypeIntersectionWithConcreteTypes] +class X: x = 1 +class Y: y = 1 +class Z(X, Y): ... + +z = Z() +x: X = z +y: Y = z +if type(x) is type(y): + reveal_type(x) # N: Revealed type is "__main__." + reveal_type(y) # N: Revealed type is "__main__." + x.y + y.x + +if isinstance(x, type(y)) and isinstance(y, type(x)): + reveal_type(x) # N: Revealed type is "__main__." + reveal_type(y) # N: Revealed type is "__main__." + x.y + y.x + +[builtins fixtures/isinstance.pyi] + [case testTypeEqualsCheckUsingIsNonOverlapping] # flags: --warn-unreachable from typing import Union y: str -if type(y) is int: # E: Subclass of "str" and "int" cannot exist: would have incompatible method signatures - y # E: Statement is unreachable +if type(y) is int: + y else: reveal_type(y) # N: Revealed type is "builtins.str" -[builtins fixtures/isinstance.pyi] [case testTypeEqualsCheckUsingIsNonOverlappingChild-xfail] # flags: --warn-unreachable @@ -2761,12 +2819,13 @@ else: [case testTypeEqualsMultipleTypesShouldntNarrow] # make sure we don't do any narrowing if there are multiple types being compared +# flags: --warn-unreachable from typing import Union x: Union[int, str] if type(x) == int == str: - reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]" + reveal_type(x) # E: Statement is unreachable else: reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]" diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 00d33c86414f..acf2d496204d 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1262,7 +1262,7 @@ def f(t: Type[C]) -> None: if type(t) is M: reveal_type(t) # N: Revealed type is "type[__main__.C]" else: - reveal_type(t) # N: Revealed type is "type[__main__.C]" + reveal_type(t) # N: Revealed type is "type[__main__.C]" if type(t) is not M: reveal_type(t) # N: Revealed type is "type[__main__.C]" else: