|
1 | | -import datetime |
| 1 | +import time |
| 2 | +from functools import wraps |
| 3 | +from typing import Dict, Optional, Tuple, Coroutine |
2 | 4 |
|
3 | | -from .key import KEY |
4 | 5 | from .lru import LRU |
| 6 | +from .types import T, AsyncFunc, Callable, Any |
5 | 7 |
|
6 | 8 |
|
7 | | -class AsyncTTL: |
8 | | - class _TTL(LRU): |
9 | | - def __init__(self, time_to_live, maxsize): |
10 | | - super().__init__(maxsize=maxsize) |
11 | | - |
12 | | - self.time_to_live = ( |
13 | | - datetime.timedelta(seconds=time_to_live) if time_to_live else None |
14 | | - ) |
| 9 | +class TTL(LRU): |
| 10 | + """Time-To-Live (TTL) cache implementation extending LRU cache.""" |
15 | 11 |
|
16 | | - self.maxsize = maxsize |
| 12 | + def __init__(self, maxsize: Optional[int] = 128, time_to_live: int = 0) -> None: |
| 13 | + super().__init__(maxsize=maxsize) |
| 14 | + self.time_to_live: int = time_to_live |
| 15 | + self.timestamps: Dict[Any, float] = {} |
17 | 16 |
|
18 | | - def __contains__(self, key): |
19 | | - if key not in self.keys(): |
| 17 | + async def contains(self, key: Any) -> bool: |
| 18 | + async with self._lock: |
| 19 | + exists = await super().contains(key) |
| 20 | + if not exists: |
20 | 21 | return False |
21 | | - else: |
22 | | - key_expiration = super().__getitem__(key)[1] |
23 | | - if key_expiration and key_expiration < datetime.datetime.now(): |
24 | | - del self[key] |
| 22 | + if self.time_to_live: |
| 23 | + timestamp: float = self.timestamps.get(key, 0) |
| 24 | + if time.time() - timestamp > self.time_to_live: |
| 25 | + del self.cache[key] |
| 26 | + del self.timestamps[key] |
25 | 27 | return False |
26 | | - else: |
27 | | - return True |
| 28 | + return True |
28 | 29 |
|
29 | | - def __getitem__(self, key): |
30 | | - value = super().__getitem__(key)[0] |
31 | | - return value |
| 30 | + async def set(self, key: Any, value: Any) -> None: |
| 31 | + async with self._lock: |
| 32 | + await super().set(key, value) |
| 33 | + if self.time_to_live: |
| 34 | + self.timestamps[key] = time.time() |
32 | 35 |
|
33 | | - def __setitem__(self, key, value): |
34 | | - ttl_value = ( |
35 | | - (datetime.datetime.now() + self.time_to_live) |
36 | | - if self.time_to_live |
37 | | - else None |
38 | | - ) |
39 | | - super().__setitem__(key, (value, ttl_value)) |
| 36 | + async def clear(self) -> None: |
| 37 | + async with self._lock: |
| 38 | + await super().clear() |
| 39 | + self.timestamps.clear() |
40 | 40 |
|
41 | | - def __init__(self, time_to_live=60, maxsize=1024, skip_args: int = 0): |
42 | | - """ |
43 | 41 |
|
44 | | - :param time_to_live: Use time_to_live as None for non expiring cache |
45 | | - :param maxsize: Use maxsize as None for unlimited size cache |
46 | | - :param skip_args: Use `1` to skip first arg of func in determining cache key |
47 | | - """ |
48 | | - self.ttl = self._TTL(time_to_live=time_to_live, maxsize=maxsize) |
49 | | - self.skip_args = skip_args |
50 | | - |
51 | | - def cache_clear(self): |
52 | | - """ |
53 | | - Clears the TTL cache. |
| 42 | +class AsyncTTL: |
| 43 | + """Async Time-To-Live (TTL) cache decorator.""" |
54 | 44 |
|
55 | | - This method empties the cache, removing all stored |
56 | | - entries and effectively resetting the cache. |
| 45 | + def __init__(self, |
| 46 | + time_to_live: int = 0, |
| 47 | + maxsize: Optional[int] = 128, |
| 48 | + skip_args: int = 0) -> None: |
| 49 | + self.ttl: TTL = TTL(maxsize=maxsize, time_to_live=time_to_live) |
| 50 | + self.skip_args: int = skip_args |
57 | 51 |
|
58 | | - :return: None |
59 | | - """ |
60 | | - self.ttl.clear() |
| 52 | + def __call__(self, func: AsyncFunc) -> Callable[..., Coroutine[Any, Any, T]]: |
| 53 | + @wraps(func) |
| 54 | + async def wrapper(*args: Any, use_cache: bool = True, **kwargs: Any) -> T: |
| 55 | + if not use_cache: |
| 56 | + return await func(*args, **kwargs) |
61 | 57 |
|
62 | | - def __call__(self, func): |
63 | | - async def wrapper(*args, use_cache=True, **kwargs): |
64 | | - key = KEY(args[self.skip_args:], kwargs) |
65 | | - if key in self.ttl and use_cache: |
66 | | - val = self.ttl[key] |
67 | | - else: |
68 | | - self.ttl[key] = await func(*args, **kwargs) |
69 | | - val = self.ttl[key] |
| 58 | + key: Tuple[Any, ...] = (*args[self.skip_args:], *sorted(kwargs.items())) |
70 | 59 |
|
71 | | - return val |
| 60 | + if await self.ttl.contains(key): |
| 61 | + return await self.ttl.get(key) |
72 | 62 |
|
73 | | - wrapper.__name__ += func.__name__ |
74 | | - wrapper.__dict__['cache_clear'] = self.cache_clear |
| 63 | + result: T = await func(*args, **kwargs) |
| 64 | + await self.ttl.set(key, result) |
| 65 | + return result |
75 | 66 |
|
| 67 | + # Add cache_clear method to the wrapper |
| 68 | + wrapper.cache_clear = self.ttl.clear # type: ignore |
76 | 69 | return wrapper |
0 commit comments