|
| 1 | +import dataclasses |
| 2 | +import logging |
| 3 | +import typing |
| 4 | + |
| 5 | +import tenacity |
| 6 | +from redis import asyncio as aioredis |
| 7 | +from redis.exceptions import ConnectionError as RedisConnectionError |
| 8 | +from redis.exceptions import WatchError |
| 9 | + |
| 10 | +from circuit_breaker_box import BaseCircuitBreaker, errors |
| 11 | + |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +def _log_attempt(retry_state: tenacity.RetryCallState) -> None: |
| 17 | + logger.info("Attempt redis_reconnect: %s", retry_state) |
| 18 | + |
| 19 | + |
| 20 | +@dataclasses.dataclass(kw_only=True, slots=True) |
| 21 | +class CircuitBreakerRedis(BaseCircuitBreaker): |
| 22 | + redis_connection: "aioredis.Redis[str]" |
| 23 | + |
| 24 | + @tenacity.retry( |
| 25 | + stop=tenacity.stop_after_attempt(3), |
| 26 | + wait=tenacity.wait_exponential_jitter(), |
| 27 | + retry=tenacity.retry_if_exception_type((WatchError, RedisConnectionError, ConnectionResetError, TimeoutError)), |
| 28 | + reraise=True, |
| 29 | + before=_log_attempt, |
| 30 | + ) |
| 31 | + async def increment_failures_count(self, host: str) -> None: |
| 32 | + redis_key: typing.Final = f"circuit-breaker-{host}" |
| 33 | + increment_result: int = await self.redis_connection.incr(redis_key) |
| 34 | + logger.debug("Incremented error for redis_key: %s, increment_result: %s", redis_key, increment_result) |
| 35 | + is_expire_set: bool = await self.redis_connection.expire(redis_key, self.reset_timeout_in_seconds) |
| 36 | + logger.debug("Expire set for redis_key: %s, is_expire_set: %s", redis_key, is_expire_set) |
| 37 | + |
| 38 | + @tenacity.retry( |
| 39 | + stop=tenacity.stop_after_attempt(3), |
| 40 | + wait=tenacity.wait_exponential_jitter(), |
| 41 | + retry=tenacity.retry_if_exception_type((WatchError, RedisConnectionError, ConnectionResetError, TimeoutError)), |
| 42 | + reraise=True, |
| 43 | + before=_log_attempt, |
| 44 | + ) |
| 45 | + async def is_host_available(self, host: str) -> bool: |
| 46 | + failures_count: typing.Final = int(await self.redis_connection.get(f"circuit-breaker-{host}") or 0) |
| 47 | + is_available: bool = failures_count <= self.max_failure_count |
| 48 | + logger.debug( |
| 49 | + "host: '%s', failures_count: '%s', self.max_failure_count: '%s', is_available: '%s'", |
| 50 | + host, |
| 51 | + failures_count, |
| 52 | + self.max_failure_count, |
| 53 | + is_available, |
| 54 | + ) |
| 55 | + return is_available |
| 56 | + |
| 57 | + async def raise_host_unavailable_error(self, host: str) -> typing.NoReturn: |
| 58 | + msg = f"Host {host} is unavailable" |
| 59 | + raise errors.HostUnavailableError(msg) |
0 commit comments