Skip to content

Commit ece4d41

Browse files
authored
fix: unpack unions inside tuples in except handlers (#17762)
Fixes #17759 mypy unpacks only top-level unions in except handlers, leaving unions inside tuples unchanged. This leads to the failed check because the Union type isn't a subtype of the BaseException type. We can fix this by simplifying and unpacking types inside tuples.
1 parent 4f5425e commit ece4d41

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

mypy/checker.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5298,16 +5298,22 @@ def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]:
52985298
"""Helper for check_except_handler_test to retrieve handler types."""
52995299
typ = get_proper_type(typ)
53005300
if isinstance(typ, TupleType):
5301-
return typ.items
5301+
merged_type = make_simplified_union(typ.items)
5302+
if isinstance(merged_type, UnionType):
5303+
return merged_type.relevant_items()
5304+
return [merged_type]
5305+
elif is_named_instance(typ, "builtins.tuple"):
5306+
# variadic tuple
5307+
merged_type = make_simplified_union((typ.args[0],))
5308+
if isinstance(merged_type, UnionType):
5309+
return merged_type.relevant_items()
5310+
return [merged_type]
53025311
elif isinstance(typ, UnionType):
53035312
return [
53045313
union_typ
53055314
for item in typ.relevant_items()
53065315
for union_typ in self.get_types_from_except_handler(item, n)
53075316
]
5308-
elif is_named_instance(typ, "builtins.tuple"):
5309-
# variadic tuple
5310-
return [typ.args[0]]
53115317
else:
53125318
return [typ]
53135319

mypy/fastparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sys
55
import warnings
66
from collections.abc import Sequence
7-
from typing import Any, Callable, Final, Literal, TypeVar, Union, cast, overload
7+
from typing import Any, Callable, Final, Literal, TypeVar, cast, overload
88

99
from mypy import defaults, errorcodes as codes, message_registry
1010
from mypy.errors import Errors

test-data/unit/check-statements.test

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,57 @@ def error_in_variadic(exc: Tuple[int, ...]) -> None:
801801

802802
[builtins fixtures/tuple.pyi]
803803

804+
[case testExceptWithMultipleTypes5]
805+
from typing import Tuple, Type, Union
806+
807+
class E1(BaseException): pass
808+
class E2(BaseException): pass
809+
class E3(BaseException): pass
810+
811+
def union_in_variadic(exc: Tuple[Union[Type[E1], Type[E2]], ...]) -> None:
812+
try:
813+
pass
814+
except exc as e:
815+
reveal_type(e) # N: Revealed type is "Union[__main__.E1, __main__.E2]"
816+
817+
def nested_union_in_variadic(exc: Tuple[Union[Type[E1], Union[Type[E2], Type[E3]]], ...]) -> None:
818+
try:
819+
pass
820+
except exc as e:
821+
reveal_type(e) # N: Revealed type is "Union[__main__.E1, __main__.E2, __main__.E3]"
822+
823+
def union_in_tuple(exc: Tuple[Union[Type[E1], Type[E2]], Type[E3]]) -> None:
824+
try:
825+
pass
826+
except exc as e:
827+
reveal_type(e) # N: Revealed type is "Union[__main__.E1, __main__.E2, __main__.E3]"
828+
829+
def error_in_variadic_union(exc: Tuple[Union[Type[E1], int], ...]) -> None:
830+
try:
831+
pass
832+
except exc as e: # E: Exception type must be derived from BaseException (or be a tuple of exception classes)
833+
pass
834+
835+
def error_in_variadic_nested_union(exc: Tuple[Union[Type[E1], Union[Type[E2], int]], ...]) -> None:
836+
try:
837+
pass
838+
except exc as e: # E: Exception type must be derived from BaseException (or be a tuple of exception classes)
839+
pass
840+
841+
def error_in_tuple_inside_variadic_nested_union(exc: Tuple[Union[Type[E1], Union[Type[E2], Tuple[Type[E3]]]], ...]) -> None:
842+
try:
843+
pass
844+
except exc as e: # E: Exception type must be derived from BaseException (or be a tuple of exception classes)
845+
pass
846+
847+
def error_in_tuple_union(exc: Tuple[Union[Type[E1], Type[E2]], Union[Type[E3], int]]) -> None:
848+
try:
849+
pass
850+
except exc as e: # E: Exception type must be derived from BaseException (or be a tuple of exception classes)
851+
pass
852+
853+
[builtins fixtures/tuple.pyi]
854+
804855
[case testExceptWithAnyTypes]
805856
from typing import Any
806857

0 commit comments

Comments
 (0)