Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping, Sequence, Set as AbstractSet
from contextlib import ExitStack, contextmanager
from functools import reduce
from typing import (
Callable,
Final,
Expand Down Expand Up @@ -166,6 +167,7 @@
is_more_precise,
is_proper_subtype,
is_same_type,
is_same_type_ranges,
is_subtype,
restrict_subtype_away,
unify_generic_callable,
Expand Down Expand Up @@ -6251,9 +6253,12 @@ def is_type_call(expr: CallExpr) -> bool:
current_type = self.get_isinstance_type(expr)
if current_type is None:
continue
if type_being_compared is not None:
if type_being_compared is not None and not is_same_type_ranges(
type_being_compared, current_type
):
# It doesn't really make sense to have several types being
# compared to the output of type (like type(x) == int == str)
# unless they are the same (like type(x) == float == float)
# because whether that's true is solely dependent on what the
# types being compared are, so we don't try to narrow types any
# further because we can't really get any information about the
Expand All @@ -6267,6 +6272,12 @@ def is_type_call(expr: CallExpr) -> bool:
if not exprs_in_type_calls:
return {}, {}

if type_being_compared is None:
least_type = reduce(
meet_types, (self.lookup_type(expr) for expr in exprs_in_type_calls)
)
type_being_compared = [TypeRange(least_type, is_upper_bound=True)]

if_maps: list[TypeMap] = []
else_maps: list[TypeMap] = []
for expr in exprs_in_type_calls:
Expand Down
23 changes: 23 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import mypy.applytype
import mypy.constraints
import mypy.typeops
from mypy.checker_shared import TypeRange
from mypy.checker_state import checker_state
from mypy.erasetype import erase_type
from mypy.expandtype import (
Expand Down Expand Up @@ -255,6 +256,28 @@ def is_equivalent(
)


def is_same_type_ranges(
a: list[TypeRange],
b: list[TypeRange],
ignore_promotions: bool = True,
subtype_context: SubtypeContext | None = None,
) -> bool:
return len(a) == len(b) and all(
is_same_type_range(x, y, ignore_promotions, subtype_context) for x, y in zip(a, b)
)


def is_same_type_range(
a: TypeRange,
b: TypeRange,
ignore_promotions: bool = True,
subtype_context: SubtypeContext | None = None,
) -> bool:
return a.is_upper_bound == b.is_upper_bound and is_same_type(
a.item, b.item, ignore_promotions, subtype_context
)


def is_same_type(
a: Type, b: Type, ignore_promotions: bool = True, subtype_context: SubtypeContext | None = None
) -> bool:
Expand Down
30 changes: 30 additions & 0 deletions test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2716,13 +2716,43 @@ if type(x) == type(y) == int:
reveal_type(y) # N: Revealed type is "builtins.int"
reveal_type(x) # N: Revealed type is "builtins.int"

z: Any
if int == type(z) == int:
reveal_type(z) # N: Revealed type is "builtins.int"

[case testTypeEqualsCheckUsingIs]
from typing import Any

y: Any
if type(y) is int:
reveal_type(y) # N: Revealed type is "builtins.int"

[case testTypeEqualsCheckUsingImplicitTypes]
from typing import Any

x: str
y: Any
z: object
if type(y) is type(x):
reveal_type(x) # N: Revealed type is "builtins.str"
reveal_type(y) # N: Revealed type is "builtins.str"

if type(x) is type(z):
reveal_type(x) # N: Revealed type is "builtins.str"
reveal_type(z) # N: Revealed type is "builtins.str"

[case testTypeEqualsCheckUsingDifferentSpecializedTypes]
from collections import defaultdict

x: defaultdict
y: dict
z: object
if type(x) is type(y) is type(z):
reveal_type(x) # N: Revealed type is "collections.defaultdict[Any, Any]"
reveal_type(y) # N: Revealed type is "collections.defaultdict[Any, Any]"
reveal_type(z) # N: Revealed type is "collections.defaultdict[Any, Any]"


[case testTypeEqualsCheckUsingIsNonOverlapping]
# flags: --warn-unreachable
from typing import Union
Expand Down