diff --git a/async_lru/__init__.py b/async_lru/__init__.py index 447e9cd..b712c75 100644 --- a/async_lru/__init__.py +++ b/async_lru/__init__.py @@ -9,9 +9,9 @@ Coroutine, Generic, Hashable, + List, Optional, OrderedDict, - Set, Type, TypedDict, TypeVar, @@ -54,8 +54,9 @@ class _CacheParameters(TypedDict): @final @dataclasses.dataclass class _CacheItem(Generic[_R]): - fut: "asyncio.Future[_R]" + task: "asyncio.Task[_R]" later_call: Optional[asyncio.Handle] + waiters: int def cancel(self) -> None: if self.later_call is not None: @@ -108,7 +109,17 @@ def __init__( self.__closed = False self.__hits = 0 self.__misses = 0 - self.__tasks: Set["asyncio.Task[_R]"] = set() + + @property + def __tasks(self) -> List["asyncio.Task[_R]"]: + # NOTE: I don't think we need to form a set first here but not too sure we want it for guarantees + return list( + { + cache_item.task + for cache_item in self.__cache.values() + if not cache_item.task.done() + } + ) def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool: key = _make_key(args, kwargs, self.__typed) @@ -128,12 +139,11 @@ def cache_clear(self) -> None: if c.later_call: c.later_call.cancel() self.__cache.clear() - self.__tasks.clear() async def cache_close(self, *, wait: bool = False) -> None: self.__closed = True - tasks = list(self.__tasks) + tasks = self.__tasks if not tasks: return @@ -167,19 +177,8 @@ def _cache_hit(self, key: Hashable) -> None: def _cache_miss(self, key: Hashable) -> None: self.__misses += 1 - def _task_done_callback( - self, fut: "asyncio.Future[_R]", key: Hashable, task: "asyncio.Task[_R]" - ) -> None: - self.__tasks.discard(task) - - if task.cancelled(): - fut.cancel() - self.__cache.pop(key, None) - return - - exc = task.exception() - if exc is not None: - fut.set_exception(exc) + def _task_done_callback(self, key: Hashable, task: "asyncio.Task[_R]") -> None: + if task.cancelled() or task.exception() is not None: self.__cache.pop(key, None) return @@ -190,7 +189,16 @@ def _task_done_callback( self.__ttl, self.__cache.pop, key, None ) - fut.set_result(task.result()) + def _handle_cancelled_error( + self, key: Hashable, cache_item: "_CacheItem[Any]" + ) -> None: + # Called when a waiter is cancelled. + # If this is the last waiter and the underlying task is not done, + # cancel the underlying task and remove the cache entry. + if cache_item.waiters == 1 and not cache_item.task.done(): + cache_item.cancel() # Cancel TTL expiration + cache_item.task.cancel() # Cancel the running coroutine + self.__cache.pop(key, None) # Remove from cache async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: if self.__closed: @@ -204,25 +212,43 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: if cache_item is not None: self._cache_hit(key) - if not cache_item.fut.done(): - return await asyncio.shield(cache_item.fut) - - return cache_item.fut.result() + if not cache_item.task.done(): + # Each logical waiter increments waiters on entry. + cache_item.waiters += 1 + + try: + # All waiters await the same shielded task. + return await asyncio.shield(cache_item.task) + except asyncio.CancelledError: + # If a waiter is cancelled, handle possible last-waiter cleanup. + self._handle_cancelled_error(key, cache_item) + raise + finally: + # Each logical waiter decrements waiters on exit (normal or cancelled). + cache_item.waiters -= 1 + # If the task is already done, just return the result. + return cache_item.task.result() - fut = loop.create_future() coro = self.__wrapped__(*fn_args, **fn_kwargs) task: asyncio.Task[_R] = loop.create_task(coro) - self.__tasks.add(task) - task.add_done_callback(partial(self._task_done_callback, fut, key)) + task.add_done_callback(partial(self._task_done_callback, key)) - self.__cache[key] = _CacheItem(fut, None) + cache_item = _CacheItem(task, None, 1) + self.__cache[key] = cache_item if self.__maxsize is not None and len(self.__cache) > self.__maxsize: - dropped_key, cache_item = self.__cache.popitem(last=False) - cache_item.cancel() + dropped_key, dropped_cache_item = self.__cache.popitem(last=False) + dropped_cache_item.cancel() self._cache_miss(key) - return await asyncio.shield(fut) + + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + self._handle_cancelled_error(key, cache_item) + raise + finally: + cache_item.waiters -= 1 def __get__( self, instance: _T, owner: Optional[Type[_T]] diff --git a/benchmark.py b/benchmark.py index b01e0e8..65b1971 100644 --- a/benchmark.py +++ b/benchmark.py @@ -305,11 +305,10 @@ async def dummy_coro(): pass iterations = range(1000) - create_future = loop.create_future callback_fn = func._task_done_callback @benchmark def run() -> None: for i in iterations: - callback = partial(callback_fn, create_future(), i) + callback = partial(callback_fn, i) callback(task) diff --git a/tests/test_basic.py b/tests/test_basic.py index ef234f0..06ddbd5 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -152,8 +152,8 @@ async def coro(val: int) -> int: assert ret1 == ret2 assert ( - coro1._LRUCacheWrapper__cache[1].fut.result() # type: ignore[attr-defined] - == coro2._LRUCacheWrapper__cache[1].fut.result() # type: ignore[attr-defined] + coro1._LRUCacheWrapper__cache[1].task.result() # type: ignore[attr-defined] + == coro2._LRUCacheWrapper__cache[1].task.result() # type: ignore[attr-defined] ) assert coro1._LRUCacheWrapper__cache != coro2._LRUCacheWrapper__cache # type: ignore[attr-defined] assert coro1._LRUCacheWrapper__cache.keys() == coro2._LRUCacheWrapper__cache.keys() # type: ignore[attr-defined] diff --git a/tests/test_cancel.py b/tests/test_cancel.py new file mode 100644 index 0000000..ee405c7 --- /dev/null +++ b/tests/test_cancel.py @@ -0,0 +1,58 @@ +import asyncio + +import pytest + +from async_lru import alru_cache + + +@pytest.mark.parametrize("num_to_cancel", [0, 1, 2, 3]) +async def test_cancel(num_to_cancel: int) -> None: + cache_item_task_finished = False + + @alru_cache + async def coro(val: int) -> int: + # I am a long running coro function + nonlocal cache_item_task_finished + await asyncio.sleep(2) + cache_item_task_finished = True + return val + + # create 3 tasks for the cached function using the same key + tasks = [asyncio.create_task(coro(1)) for _ in range(3)] + + # force the event loop to run once so the tasks can begin + await asyncio.sleep(0) + + # maybe cancel some tasks + for i in range(num_to_cancel): + tasks[i].cancel() + + # allow enough time for the non-cancelled tasks to complete + await asyncio.sleep(3) + + # check state + assert cache_item_task_finished == (num_to_cancel < 3) + + +@pytest.mark.asyncio +async def test_cancel_single_waiter_triggers_handle_cancelled_error() -> None: + # This test ensures the _handle_cancelled_error path (waiters == 1) is exercised. + cache_item_task_finished = False + + @alru_cache + async def coro(val: int) -> int: + nonlocal cache_item_task_finished + await asyncio.sleep(2) + cache_item_task_finished = True + return val + + task = asyncio.create_task(coro(42)) + await asyncio.sleep(0) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # The underlying coroutine should be cancelled, so the flag should remain False + assert cache_item_task_finished is False diff --git a/tests/test_internals.py b/tests/test_internals.py index e5a055c..cd63c09 100644 --- a/tests/test_internals.py +++ b/tests/test_internals.py @@ -11,30 +11,26 @@ async def test_done_callback_cancelled() -> None: wrapped = _LRUCacheWrapper(mock.ANY, None, False, None) loop = asyncio.get_running_loop() task = loop.create_future() - fut = loop.create_future() key = 1 - task.add_done_callback(partial(wrapped._task_done_callback, fut, key)) - wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined] + task.add_done_callback(partial(wrapped._task_done_callback, key)) task.cancel() await asyncio.sleep(0) - assert fut.cancelled() + assert task not in wrapped._LRUCacheWrapper__tasks # type: ignore[attr-defined] async def test_done_callback_exception() -> None: wrapped = _LRUCacheWrapper(mock.ANY, None, False, None) loop = asyncio.get_running_loop() task = loop.create_future() - fut = loop.create_future() key = 1 - task.add_done_callback(partial(wrapped._task_done_callback, fut, key)) - wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined] + task.add_done_callback(partial(wrapped._task_done_callback, key)) exc = ZeroDivisionError() @@ -42,31 +38,7 @@ async def test_done_callback_exception() -> None: await asyncio.sleep(0) - with pytest.raises(ZeroDivisionError): - await fut - - with pytest.raises(ZeroDivisionError): - fut.result() - - assert fut.exception() is exc - - -async def test_done_callback() -> None: - wrapped = _LRUCacheWrapper(mock.ANY, None, False, None) - loop = asyncio.get_running_loop() - task = loop.create_future() - - key = 1 - fut = loop.create_future() - - task.add_done_callback(partial(wrapped._task_done_callback, fut, key)) - wrapped._LRUCacheWrapper__tasks.add(task) # type: ignore[attr-defined] - - task.set_result(1) - - await asyncio.sleep(0) - - assert fut.result() == 1 + assert task not in wrapped._LRUCacheWrapper__tasks # type: ignore[attr-defined] async def test_cache_invalidate_typed() -> None: