
How to use functools.wraps with method decorators #15087

@sfc-gh-bchinn


Say I have a noop decorator:

import functools
from typing import Callable

# Python 3.12+ (PEP 695 generic syntax)
def trace[**P, T](func: Callable[P, T]) -> Callable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

Great. Now suppose we replace Callable with a Protocol that defines __call__, for example because we want to access func.__name__ or inspect the specific args/kwargs being passed in:

import functools
from typing import Protocol

class MyCallable[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

mypy rejects this with:

error: Missing positional argument "a" in call to "__call__" of "MyCallable"  [call-arg]
error: Argument 1 to "__call__" of "MyCallable" has incompatible type "int"; expected "Foo"  [arg-type]
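These errors come from how attribute lookup on an instance works. A Protocol that only declares __call__ describes an object that is not a descriptor, and such an object is not bound to the instance when accessed as Foo().foo. So mypy checks Foo().foo(1) as a direct call to MyCallable.__call__, matching 1 against the self: Foo parameter and reporting a as missing. The runtime distinction mypy is modelling can be seen with two hypothetical attributes, CallableObject and plain_function:

```python
class CallableObject:
    """Hypothetical object with __call__ but no __get__, i.e. not a descriptor."""

    def __call__(self, *args: object) -> tuple[object, ...]:
        return args


def plain_function(*args: object) -> tuple[object, ...]:
    """A plain function; every Python function implements __get__."""
    return args


class Foo:
    obj = CallableObject()
    func = plain_function


foo = Foo()
# The callable object is returned unchanged when looked up on an instance,
# so the instance is NOT passed as the first argument.
print(foo.obj(1))               # (1,)
# function.__get__ binds the instance, so it arrives as the first argument.
print(foo.func(1) == (foo, 1))  # True
```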

Per python/typing#1040, the Protocol should also define __get__, returning a second Protocol that describes the bound signature (with self already applied):

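import functools
from typing import Any, Protocol

# The bound form is defined first so that the annotation on MyCallable.__get__
# can be evaluated when the class body runs (needed before Python 3.14).
class MyCallableBound[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

class MyCallable[**P, T](Protocol):
    def __call__(self_, self: Any, *args: P.args, **kwargs: P.kwargs) -> T: ...
    def __get__(self_, *args: Any, **kwargs: Any) -> MyCallableBound[P, T]: ...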

def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
        return func(self, *args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

Now this fails with:

error: Incompatible return value type (got "_Wrapped[[Any, **P], T, [Any, **P], T]", expected "MyCallable[P, T]")  [return-value]
note: "_Wrapped" is missing following "MyCallable" protocol member:
note:     __get__
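The _Wrapped in this error is the type that typeshed's functools stub gives to the result of functools.wraps, and per the note above it declares no __get__. At runtime there is no separate wrapper object: functools.wraps copies metadata and hands back the wrapper function itself, which binds like any other method. A small check (original and wrapper are illustrative names):

```python
import functools
from typing import Any


def original(self: Any, a: int) -> str:
    return "foo"


def wrapper(self: Any, *args: Any, **kwargs: Any) -> str:
    return original(self, *args, **kwargs)


result = functools.wraps(original)(wrapper)

# functools.wraps returns the wrapper function object itself...
print(result is wrapper)           # True
# ...which, like every Python function, implements __get__.
print(hasattr(result, "__get__"))  # True


class Foo:
    foo = result


print(Foo().foo(1))      # foo
print(Foo.foo.__name__)  # original
```

So the mismatch is in how the return value of functools.wraps is typed, not in what it does at runtime.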

It type-checks if I comment out @functools.wraps(func). For now, we can work around it with a wraps that returns the wrapper's own type unchanged:

import contextlib
import functools
from typing import Callable

def wraps[T](func: T) -> Callable[[T], T]:
    """Like functools.wraps, but typed to return the wrapper's own type unchanged."""
    def decorator(new_func: T) -> T:
        # Copy __module__, __name__, __qualname__, __doc__, etc. The exact list
        # depends on the Python version (e.g. __annotate__ on 3.14+).
        for attr in functools.WRAPPER_ASSIGNMENTS:
            with contextlib.suppress(AttributeError):
                setattr(new_func, attr, getattr(func, attr))
        # Merge the wrapped function's __dict__ into the wrapper's.
        for attr in functools.WRAPPER_UPDATES:
            getattr(new_func, attr).update(getattr(func, attr, {}))
        # Set last, as functools.update_wrapper does, so the __dict__ merge can't overwrite it.
        setattr(new_func, "__wrapped__", func)
        return new_func

    return decorator
