Say I have a noop decorator:
import functools
from typing import Callable

def trace[**P, T](func: Callable[P, T]) -> Callable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

Great. Now let's replace Callable with a Protocol that defines __call__, for example because we want to access func.__name__ or inspect specific args/kwargs being passed in:
import functools
from typing import Protocol

class MyCallable[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

This fails under mypy with:
error: Missing positional argument "a" in call to "__call__" of "MyCallable" [call-arg]
error: Argument 1 to "__call__" of "MyCallable" has incompatible type "int"; expected "Foo" [arg-type]
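A likely reading of these two errors: a protocol that only declares __call__ describes a plain callable object, not a descriptor, so nothing binds self when the attribute is looked up on an instance. The runtime sketch below illustrates that difference between a function and a callable object. PlainCallable and plain_function are hypothetical names used only for this illustration.

```python
class PlainCallable:
    """A callable object without __get__, so it is not a descriptor."""

    def __call__(self, *args):
        return args


def plain_function(*args):
    return args


class Foo:
    as_object = PlainCallable()   # looked up as-is, never bound
    as_function = plain_function  # functions define __get__, so they bind


foo = Foo()
object_args = foo.as_object(1)      # the instance is NOT passed in
function_args = foo.as_function(1)  # the instance is passed as first argument

object_is_descriptor = hasattr(Foo.__dict__["as_object"], "__get__")
function_is_descriptor = hasattr(Foo.__dict__["as_function"], "__get__")
```

Here object_args contains only the explicit argument, while function_args also contains foo, which matches how mypy treats MyCallable in the failing example.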
Per python/typing#1040, we should add __get__ to return a Protocol with the post-bound signature:
import functools
from typing import Any, Protocol

# Defined before MyCallable so the __get__ return annotation also resolves
# on Python versions that evaluate annotations eagerly.
class MyCallableBound[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

class MyCallable[**P, T](Protocol):
    def __call__(self_, self: Any, *args: P.args, **kwargs: P.kwargs) -> T: ...
    def __get__(self_, *args: Any, **kwargs: Any) -> MyCallableBound[P, T]: ...

def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
        return func(self, *args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

Now this fails with:
error: Incompatible return value type (got "_Wrapped[[Any, **P], T, [Any, **P], T]", expected "MyCallable[P, T]") [return-value]
note: "_Wrapped" is missing following "MyCallable" protocol member:
note: __get__
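The mismatch seems to exist only at the type level. At runtime, functools.wraps(func) hands back the same function object it was given, and that object still has __get__. A small check, assuming nothing beyond the standard library (original and wrapper are hypothetical names):

```python
import functools


def original(self, a):
    return "foo"


def wrapper(self, a):
    return original(self, a)


wrapped = functools.wraps(original)(wrapper)

same_object = wrapped is wrapper          # wraps() mutates and returns wrapper
has_get = hasattr(wrapped, "__get__")     # still a normal function descriptor
copied_name = wrapped.__name__            # metadata copied from original
points_back = wrapped.__wrapped__ is original
```

So the "_Wrapped" type in the error message describes the stub's return type rather than a distinct runtime object.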
It works if I comment out @functools.wraps(). For now, we can work around it with:
import contextlib
from typing import Callable

WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
                       '__annotate__', '__type_params__')
WRAPPER_UPDATES = ('__dict__',)

def wraps[T](func: T) -> Callable[[T], T]:
    def decorator(new_func: T) -> T:
        # Copy metadata, skipping attributes that func does not have.
        for attr in WRAPPER_ASSIGNMENTS:
            with contextlib.suppress(AttributeError):
                setattr(new_func, attr, getattr(func, attr))
        for attr in WRAPPER_UPDATES:
            getattr(new_func, attr).update(getattr(func, attr, {}))
        setattr(new_func, "__wrapped__", func)
        # Returning T (not a wrapper type) keeps the declared type intact.
        return new_func
    return decorator
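A runtime-only sketch of the workaround in use, to show that metadata is copied and method binding still works. It uses the pre-3.12 TypeVar spelling so it runs on older interpreters, and the annotations on trace are deliberately loose (Callable[..., Any]); it makes no claim about what mypy reports for the protocol version above.

```python
import contextlib
from typing import Any, Callable, TypeVar

T = TypeVar("T")

WRAPPER_ASSIGNMENTS = ("__module__", "__name__", "__qualname__", "__doc__",
                       "__annotate__", "__type_params__")
WRAPPER_UPDATES = ("__dict__",)


def wraps(func: T) -> Callable[[T], T]:
    def decorator(new_func: T) -> T:
        # Attributes missing on this Python version are silently skipped.
        for attr in WRAPPER_ASSIGNMENTS:
            with contextlib.suppress(AttributeError):
                setattr(new_func, attr, getattr(func, attr))
        for attr in WRAPPER_UPDATES:
            getattr(new_func, attr).update(getattr(func, attr, {}))
        setattr(new_func, "__wrapped__", func)
        return new_func
    return decorator


def trace(func: Callable[..., Any]) -> Callable[..., Any]:
    @wraps(func)
    def func_with_log(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)
    return func_with_log


class Foo:
    @trace
    def foo(self, a: int) -> str:
        """Return a constant."""
        return "foo"


result = Foo().foo(1)                 # binding still passes the instance
name = Foo.foo.__name__               # copied from the original method
doc = Foo.foo.__doc__
has_wrapped = hasattr(Foo.foo, "__wrapped__")
```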