1+ import asyncio
12from collections .abc import Awaitable , Callable , Coroutine
23from dataclasses import dataclass , field
34from typing import Any
1415 UnsubscribeFrame ,
1516)
1617
17- ActiveSubscriptions = dict [str , "AutoAckSubscription | ManualAckSubscription" ]
18+
19+ @dataclass (kw_only = True , slots = True , frozen = True )
20+ class ActiveSubscriptions :
21+ subscriptions : dict [str , "AutoAckSubscription | ManualAckSubscription" ] = field (default_factory = dict , init = False )
22+ event : asyncio .Event = field (default_factory = asyncio .Event , init = False )
23+
24+ def __post_init__ (self ) -> None :
25+ self .event .set ()
26+
27+ def get_by_id (self , subscription_id : str ) -> "AutoAckSubscription | ManualAckSubscription | None" :
28+ return self .subscriptions .get (subscription_id )
29+
30+ def get_all (self ) -> list ["AutoAckSubscription | ManualAckSubscription" ]:
31+ return list (self .subscriptions .values ())
32+
33+ def delete_by_id (self , subscription_id : str ) -> None :
34+ del self .subscriptions [subscription_id ]
35+ if not self .subscriptions :
36+ self .event .set ()
37+
38+ def add (self , subscription : "AutoAckSubscription | ManualAckSubscription" ) -> None :
39+ self .subscriptions [subscription .id ] = subscription
40+ self .event .clear ()
41+
42+ def contains_by_id (self , subscription_id : str ) -> bool :
43+ return subscription_id in self .subscriptions
44+
45+ async def wait_until_empty (self ) -> bool :
46+ return await self .event .wait ()
1847
1948
2049@dataclass (kw_only = True , slots = True )
@@ -32,20 +61,20 @@ async def _subscribe(self) -> None:
3261 subscription_id = self .id , destination = self .destination , ack = self .ack , headers = self .headers
3362 )
3463 )
35- self ._active_subscriptions [ self . id ] = self # type: ignore[assignment ]
64+ self ._active_subscriptions . add ( self ) # type: ignore[arg-type ]
3665
3766 async def unsubscribe (self ) -> None :
38- del self ._active_subscriptions [ self .id ]
67+ self ._active_subscriptions . delete_by_id ( self .id )
3968 await self ._connection_manager .maybe_write_frame (UnsubscribeFrame (headers = {"id" : self .id }))
4069
4170 async def _nack (self , frame : MessageFrame ) -> None :
42- if self .id in self ._active_subscriptions and (ack_id := frame .headers .get ("ack" )):
71+ if self ._active_subscriptions . contains_by_id ( self .id ) and (ack_id := frame .headers .get ("ack" )):
4372 await self ._connection_manager .maybe_write_frame (
4473 NackFrame (headers = {"id" : ack_id , "subscription" : frame .headers ["subscription" ]})
4574 )
4675
4776 async def _ack (self , frame : MessageFrame ) -> None :
48- if self .id in self ._active_subscriptions and (ack_id := frame .headers ["ack" ]):
77+ if self ._active_subscriptions . contains_by_id ( self .id ) and (ack_id := frame .headers ["ack" ]):
4978 await self ._connection_manager .maybe_write_frame (
5079 AckFrame (headers = {"id" : ack_id , "subscription" : frame .headers ["subscription" ]})
5180 )
@@ -96,7 +125,7 @@ def _make_subscription_id() -> str:
96125async def resubscribe_to_active_subscriptions (
97126 * , connection : AbstractConnection , active_subscriptions : ActiveSubscriptions
98127) -> None :
99- for subscription in active_subscriptions .values ():
128+ for subscription in active_subscriptions .get_all ():
100129 await connection .write_frame (
101130 SubscribeFrame .build (
102131 subscription_id = subscription .id ,
@@ -108,5 +137,5 @@ async def resubscribe_to_active_subscriptions(
108137
109138
110139async def unsubscribe_from_all_active_subscriptions (* , active_subscriptions : ActiveSubscriptions ) -> None :
111- for subscription in active_subscriptions .copy (). values ():
140+ for subscription in active_subscriptions .get_all ():
112141 await subscription .unsubscribe ()
0 commit comments