Skip to content

Proposal for a new way to overload methods by argument types and by whether they are synchronous or asynchronous #141628

@rroblf01

Description

@rroblf01

Feature or enhancement

Proposal:

Inspired by `functools.singledispatch`, I think it would be very useful to have a decorator that overloads functions or methods not only by argument type but also by whether the overload is synchronous or asynchronous.

This can greatly simplify the use of libraries containing both synchronous and asynchronous code that perform the same task but use different dependencies depending on the context. It can also facilitate the migration of synchronous code to asynchronous code.

Example of use:

import asyncio


class Example:
    """Demo class whose ``process`` method has four overloads.

    The overloads are intended to be chosen by the type of ``x`` (int or str)
    and by whether ``process`` is called from synchronous code or awaited
    inside a coroutine (see the calls and expected outputs below the class).
    """

    @coroutinedispatch
    def process(self, x: int) -> str:
        """Synchronous overload for int arguments (the first one registered)."""
        return f"Processing integer: {x}"

    # Each ``process.register`` adds another overload to the same dispatcher;
    # the name ``_`` is only a placeholder for the decorated function.
    @process.register
    async def _(self, x: int) -> str:
        # Asynchronous overload for int arguments.
        await asyncio.sleep(0.1)  # stand-in for real asynchronous work
        return f"Processing integer asynchronously: {x}"

    @process.register
    def _(self, x: str) -> str:
        # Synchronous overload for str arguments.
        return f"Processing string: {x}"

    @process.register
    async def _(self, x: str) -> str:
        # Asynchronous overload for str arguments.
        await asyncio.sleep(0.1)  # stand-in for real asynchronous work
        return f"Processing string asynchronously: {x}"


example = Example()
print(example.process(42))  # Processing integer: 42
print(example.process("hello"))  # Processing string: hello


async def main():
    """Call ``process`` from inside a coroutine, where the async overloads should run."""
    for value in (42, "hello"):
        # Expected: "Processing integer asynchronously: 42", then
        # "Processing string asynchronously: hello".
        print(await example.process(value))


asyncio.run(main())

An implementation might look something like this:

import asyncio
import inspect
import types
from functools import lru_cache
from typing import Any, Callable, Literal, Union, get_args, get_origin, get_type_hints


class coroutinedispatch:
    """Overload a function or method by argument types *and* by sync/async context.

    The decorated function is the first overload; add more with
    ``@<name>.register``.  Coroutine functions are filed as async overloads,
    everything else as sync overloads.  On each call the dispatcher:

    1. decides the context: async if an event loop is running in the current
       thread, sync otherwise;
    2. tries the overloads of that kind, in registration order, and calls the
       first one whose parameter annotations accept the positional arguments;
    3. if none matches, tries the overloads of the other kind the same way.

    Keyword arguments are forwarded to the chosen overload but do not take
    part in matching.  When the first registered function's first parameter
    is named ``self`` or ``cls``, the dispatcher is treated as a method and
    the first positional argument (the instance or class) is not type-checked.
    """

    # typing.Union, plus types.UnionType for ``X | Y`` annotations (3.10+).
    _UNION_TYPES = tuple(
        union
        for union in (Union, getattr(types, "UnionType", None))
        if union is not None
    )

    def __init__(self, func: Callable):
        self._sync_methods = {}
        self._async_methods = {}
        self._name = func.__name__
        # get_arg_types drops a leading self/cls parameter, so for methods the
        # bound instance/class must be dropped from the call arguments as well.
        self._is_method = self._has_leading_self(func)
        self._add(func)

    @staticmethod
    def _has_leading_self(func: Callable) -> bool:
        """Return True if *func*'s first parameter is named ``self`` or ``cls``."""
        try:
            params = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):  # no introspectable signature
            return False
        return bool(params) and params[0].name in ("self", "cls")

    def _add(self, func: Callable) -> None:
        """File *func* under its argument types in the sync or async table."""
        table = (
            self._async_methods
            if inspect.iscoroutinefunction(func)
            else self._sync_methods
        )
        table[self.get_arg_types(func)] = func

    def get_arg_types(self, func: Callable) -> tuple:
        """Return the annotated type of each parameter of *func*.

        A leading ``self``/``cls`` parameter is skipped and unannotated
        parameters map to ``Any``.  If introspection fails (for example an
        annotation names a class that cannot be resolved yet), ``()`` is
        returned, so the overload only matches calls that pass no positional
        arguments (besides the instance for methods).
        """
        try:
            hints = get_type_hints(func)
            sig = inspect.signature(func)
            params = list(sig.parameters.values())

            if params and params[0].name in ("self", "cls"):
                params = params[1:]

            return tuple(hints.get(p.name, Any) for p in params)
        except Exception:
            # NOTE(review): deliberate best-effort fallback; an unresolvable
            # annotation silently makes the overload effectively unreachable.
            return ()

    def match_types(self, provided_args: tuple, expected_types: tuple) -> bool:
        """Return True if each value in *provided_args* satisfies the type at the same index.

        ``Any`` accepts everything, ``Union``/``Optional``/``X | Y`` accept a
        value matching any member, ``Literal`` accepts its listed values, and a
        parameterised generic such as ``list[int]`` is checked against its
        origin (``list``) only.

        Deliberately not memoised: arguments may be unhashable, and values that
        compare equal (``1`` and ``1.0``) can have different types.
        """
        if len(provided_args) != len(expected_types):
            return False
        return all(
            self._matches(arg, expected)
            for arg, expected in zip(provided_args, expected_types)
        )

    def _matches(self, arg: Any, expected: Any) -> bool:
        """Return True if the single value *arg* satisfies the annotation *expected*."""
        if expected is Any:
            return True

        origin = get_origin(expected)
        if origin in self._UNION_TYPES:
            return any(self._matches(arg, member) for member in get_args(expected))
        if origin is Literal:
            return arg in get_args(expected)
        if origin is not None:
            expected = origin

        try:
            return isinstance(arg, expected)
        except TypeError:
            # Not a runtime-checkable class (e.g. a TypeVar or an unresolved
            # string annotation): report "no match" instead of letting the
            # error escape and be confused with a failed lookup.
            return False

    def _search(self, is_async: bool, args: tuple) -> "Callable | None":
        """Return the first overload of the given kind that accepts *args*, or None."""
        methods = self._async_methods if is_async else self._sync_methods
        for arg_types, method in methods.items():
            if self.match_types(args, arg_types):
                return method
        return None

    def _find_matching_method(self, is_async: bool, *args: Any) -> Callable:
        """Find the method that matches the argument types.

        Raises:
            TypeError: if no overload of the requested kind accepts *args*.
        """
        method = self._search(is_async, args)
        if method is not None:
            return method

        # If no match, raise descriptive error
        methods = self._async_methods if is_async else self._sync_methods
        arg_type_names = tuple(type(arg).__name__ for arg in args)
        context = "async" if is_async else "sync"
        available = list(methods.keys())

        raise TypeError(
            f"No matching {context} method '{self._name}' found for "
            f"arguments: {arg_type_names}. Available: {available}"
        )

    def register(self, func: Callable) -> Callable:
        """Register *func* as another overload and return the dispatcher.

        Returning the dispatcher (rather than *func*) means an overload may
        reuse any name, even the dispatcher's own, without replacing it.
        """
        self._add(func)
        return self

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Select the overload for the current context and argument types, then call it.

        Raises:
            TypeError: if neither a sync nor an async overload accepts the
                positional arguments.  Exceptions raised by the chosen
                overload itself propagate unchanged.
        """
        # get_running_loop() succeeds only when called from code running inside
        # an event loop.  get_event_loop() may, depending on the Python version,
        # create or return an idle loop, misreporting sync callers as async.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            is_async_context = False
        else:
            is_async_context = True

        # Exclude the bound instance/class: get_arg_types drops self/cls too.
        check_args = args[1:] if self._is_method and args else args

        method = self._search(is_async_context, check_args)
        if method is None:
            # No overload of the preferred kind: fall back to the other kind.
            method = self._find_matching_method(not is_async_context, *check_args)

        # Invoked outside any try/except so a TypeError raised *by the overload*
        # is not mistaken for a failed lookup and silently rerouted.
        return method(*args, **kwargs)

    def __get__(self, obj, objtype=None):
        """Bind the dispatcher to *obj* when accessed through an instance."""
        if obj is None:
            return self

        import functools

        return functools.partial(self.__call__, obj)

Has this already been discussed elsewhere?

No response given

Links to previous discussion of this feature:

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    type-feature (A feature request or enhancement)

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions