-
-
Notifications
You must be signed in to change notification settings - Fork 33.5k
Closed as not planned
Copy link
Labels
type-feature: A feature request or enhancement
Description
Feature or enhancement
Proposal:
Inspired by `functools.singledispatch`, I think it would be very useful to have a decorator that can overload functions or methods with both synchronous and asynchronous implementations, selecting between them depending on whether they are called from synchronous or asynchronous code.
This could greatly simplify the use of libraries that provide both synchronous and asynchronous implementations of the same task, each relying on different dependencies depending on the context. It could also ease the migration of synchronous code to asynchronous code.
Example of use:
import asyncio

# Usage sketch for the proposed ``coroutinedispatch`` decorator: the same call
# site resolves to a sync overload from plain code and to an async overload
# when called from inside a running event loop (see the expected outputs).


class Example:
    """Demo class whose ``process`` method has sync and async overloads for int and str."""

    @coroutinedispatch
    def process(self, x: int) -> str:
        return f"Processing integer: {x}"

    @process.register
    async def _(self, x: int) -> str:
        await asyncio.sleep(0.1)
        return f"Processing integer asynchronously: {x}"

    @process.register
    def _(self, x: str) -> str:
        return f"Processing string: {x}"

    @process.register
    async def _(self, x: str) -> str:
        await asyncio.sleep(0.1)
        return f"Processing string asynchronously: {x}"


example = Example()
print(example.process(42))  # Processing integer: 42
print(example.process("hello"))  # Processing string: hello


async def main():
    print(await example.process(42))  # Processing integer asynchronously: 42
    print(await example.process("hello"))  # Processing string asynchronously: hello


asyncio.run(main())

# A possible implementation might look something like this:
import asyncio
import inspect
import types
from functools import lru_cache
from typing import Any, Callable, Union, get_args, get_origin, get_type_hints
class coroutinedispatch:
    """Dispatch a call to a sync or async overload chosen by calling context.

    Overloads are added with :meth:`register`. Each overload is stored under
    the tuple of its parameter annotations (excluding a leading ``self``/``cls``)
    in either the sync or the async table, depending on whether it is a
    coroutine function. When called from inside a running event loop the async
    table is searched first, otherwise the sync table; if the preferred table
    has no matching overload, the other table is tried.
    """

    # Origins reported by ``get_origin`` for ``Union[...]``/``Optional[...]``
    # and for PEP 604 ``X | Y`` unions (``types.UnionType`` exists on 3.10+).
    _UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))

    def __init__(self, func: Callable):
        self._sync_methods = {}
        self._async_methods = {}
        self._name = func.__name__
        # Whether overloads take a receiver as their first positional argument.
        # Decided from the parameter name (consistent with get_arg_types)
        # instead of probing the runtime argument with hasattr(), which
        # misfired on plain functions, e.g. a function named ``upper`` called
        # with a str.
        self._skips_first = self._has_receiver(func)
        self.register(func)

    @staticmethod
    def _has_receiver(func: Callable) -> bool:
        """Return True if *func*'s first parameter is named ``self`` or ``cls``."""
        try:
            params = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):
            return False
        return bool(params) and params[0].name in ("self", "cls")

    def get_arg_types(self, func: Callable) -> tuple:
        """Return the annotated types of *func*'s parameters, in order.

        A leading ``self``/``cls`` parameter is skipped and unannotated
        parameters map to ``Any``. If hints or the signature cannot be
        resolved, ``()`` is returned, so that overload only matches calls
        with no (non-receiver) positional arguments.
        """
        try:
            hints = get_type_hints(func)
            sig = inspect.signature(func)
            params = list(sig.parameters.values())
            if params and params[0].name in ("self", "cls"):
                params = params[1:]
            return tuple(hints.get(p.name, Any) for p in params)
        except Exception:
            # Best effort: unresolvable forward references and similar
            # failures fall back to an empty signature rather than failing
            # registration.
            return ()

    def match_types(self, provided_args: tuple, expected_types: tuple) -> bool:
        """Return True if every provided argument satisfies its expected type.

        The two tuples must have the same length. Deliberately not cached:
        arguments may be unhashable, and value-keyed caching conflates equal
        values of different types such as ``1``, ``1.0`` and ``True``.
        """
        if len(provided_args) != len(expected_types):
            return False
        return all(
            self._matches(arg, expected)
            for arg, expected in zip(provided_args, expected_types)
        )

    def _matches(self, arg: Any, expected: Any) -> bool:
        """Return True if *arg* satisfies the single annotation *expected*."""
        if expected is Any:
            return True
        origin = get_origin(expected)
        if origin in self._UNION_ORIGINS:
            # isinstance() rejects typing.Union itself, so test each member.
            return any(self._matches(arg, member) for member in get_args(expected))
        if origin is not None:
            # Parameterized generics (e.g. list[int]) are checked against
            # their bare origin only; element types are not inspected.
            expected = origin
        return isinstance(arg, expected)

    def _find_matching_method(self, is_async: bool, *args: Any) -> Callable:
        """Return the first overload of the requested kind matching *args*.

        Overloads are tried in the table's insertion order.

        Raises:
            TypeError: if no overload of that kind matches.
        """
        methods = self._async_methods if is_async else self._sync_methods
        for arg_types, method in methods.items():
            if self.match_types(args, arg_types):
                return method
        # No match: raise a descriptive error.
        arg_type_names = tuple(type(arg).__name__ for arg in args)
        context = "async" if is_async else "sync"
        available = list(methods.keys())
        raise TypeError(
            f"No matching {context} method '{self._name}' found for "
            f"arguments: {arg_type_names}. Available: {available}"
        )

    def register(self, func: Callable) -> Callable:
        """Register *func* as an overload and return the dispatcher itself.

        Coroutine functions go into the async table, everything else into
        the sync table. A later function with the same parameter types and
        kind replaces the earlier one.
        """
        arg_types = self.get_arg_types(func)
        if inspect.iscoroutinefunction(func):
            self._async_methods[arg_types] = func
        else:
            self._sync_methods[arg_types] = func
        return self

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the overload matching the positional argument types.

        Inside a running event loop async overloads are preferred, otherwise
        sync ones; the other kind is used as a fallback. Keyword arguments
        are passed through but are not used for matching.

        Raises:
            TypeError: if neither kind has a matching overload.
        """
        # get_running_loop() raises RuntimeError unless called from code that
        # a running loop is executing. get_event_loop() may return a loop that
        # is not running, which misclassified plain sync calls as async.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            prefer_async = False
        else:
            prefer_async = True

        # Drop the receiver (bound via __get__, or passed explicitly when
        # calling through the class) so it does not take part in matching.
        check_args = args[1:] if self._skips_first and args else args

        # Guard only the lookup: a TypeError raised by the overload itself
        # must propagate instead of triggering a call of the other kind.
        try:
            method = self._find_matching_method(prefer_async, *check_args)
        except TypeError:
            method = self._find_matching_method(not prefer_async, *check_args)
        return method(*args, **kwargs)

    def __get__(self, obj, objtype=None):
        """Bind the dispatcher to *obj* so instance access acts like a method.

        Access through the class (``obj is None``) returns the dispatcher.
        """
        if obj is None:
            return self
        import functools
        return functools.partial(self.__call__, obj)

# Has this already been discussed elsewhere?
No response given
Links to previous discussion of this feature:
No response
Metadata
Assignees
Labels
type-feature: A feature request or enhancement