@@ -132,7 +132,9 @@ class MockConnection(BaseMockConnection):
132132
133133
134134async def test_get_active_connection_state_lifespan_flaky_ok () -> None :
135- enter = mock .AsyncMock (side_effect = [ConnectionLostError , build_dataclass (EstablishedConnectionResult )])
135+ enter = mock .AsyncMock (
136+ side_effect = [build_dataclass (ConnectionLostError ), build_dataclass (EstablishedConnectionResult )]
137+ )
136138 lifespan_factory = mock .Mock (return_value = mock .Mock (enter = enter ))
137139 manager = EnrichedConnectionManager (lifespan_factory = lifespan_factory , connection_class = BaseMockConnection )
138140
@@ -154,7 +156,7 @@ async def test_get_active_connection_state_lifespan_flaky_ok() -> None:
154156
155157
156158async def test_get_active_connection_state_lifespan_flaky_fails () -> None :
157- enter = mock .AsyncMock (side_effect = ConnectionLostError )
159+ enter = mock .AsyncMock (side_effect = build_dataclass ( ConnectionLostError ) )
158160 lifespan_factory = mock .Mock (return_value = mock .Mock (enter = enter ))
159161 manager = EnrichedConnectionManager (lifespan_factory = lifespan_factory , connection_class = BaseMockConnection )
160162
@@ -209,16 +211,16 @@ async def test_get_active_connection_state_ok_concurrent() -> None:
209211
210212async def test_connection_manager_context_connection_lost () -> None :
211213 async with EnrichedConnectionManager (connection_class = BaseMockConnection ) as manager :
212- manager ._clear_active_connection_state ()
213- manager ._clear_active_connection_state ()
214+ manager ._clear_active_connection_state (build_dataclass ( ConnectionLostError ) )
215+ manager ._clear_active_connection_state (build_dataclass ( ConnectionLostError ) )
214216
215217
216218async def test_connection_manager_context_lifespan_aexit_raises_connection_lost () -> None :
217219 async with EnrichedConnectionManager (
218220 lifespan_factory = mock .Mock (
219221 return_value = mock .Mock (
220222 enter = mock .AsyncMock (return_value = build_dataclass (EstablishedConnectionResult )),
221- exit = mock .AsyncMock (side_effect = [ConnectionLostError ]),
223+ exit = mock .AsyncMock (side_effect = [build_dataclass ( ConnectionLostError ) ]),
222224 )
223225 ),
224226 connection_class = BaseMockConnection ,
@@ -248,7 +250,13 @@ class MockConnection(BaseMockConnection):
248250
249251
250252async def test_write_heartbeat_reconnecting_raises () -> None :
251- write_heartbeat_mock = mock .Mock (side_effect = [ConnectionLostError , ConnectionLostError , ConnectionLostError ])
253+ write_heartbeat_mock = mock .Mock (
254+ side_effect = [
255+ build_dataclass (ConnectionLostError ),
256+ build_dataclass (ConnectionLostError ),
257+ build_dataclass (ConnectionLostError ),
258+ ]
259+ )
252260
253261 class MockConnection (BaseMockConnection ):
254262 write_heartbeat = write_heartbeat_mock
@@ -260,7 +268,13 @@ class MockConnection(BaseMockConnection):
260268
261269
262270async def test_write_frame_reconnecting_raises () -> None :
263- write_frame_mock = mock .AsyncMock (side_effect = [ConnectionLostError , ConnectionLostError , ConnectionLostError ])
271+ write_frame_mock = mock .AsyncMock (
272+ side_effect = [
273+ build_dataclass (ConnectionLostError ),
274+ build_dataclass (ConnectionLostError ),
275+ build_dataclass (ConnectionLostError ),
276+ ]
277+ )
264278
265279 class MockConnection (BaseMockConnection ):
266280 write_frame = write_frame_mock
@@ -271,7 +285,11 @@ class MockConnection(BaseMockConnection):
271285 await manager .write_frame_reconnecting (build_dataclass (ConnectFrame ))
272286
273287
274- SIDE_EFFECTS = [(None ,), (ConnectionLostError (), None ), (ConnectionLostError (), ConnectionLostError (), None )]
288+ SIDE_EFFECTS = [
289+ (None ,),
290+ (build_dataclass (ConnectionLostError ), None ),
291+ (build_dataclass (ConnectionLostError ), build_dataclass (ConnectionLostError ), None ),
292+ ]
275293
276294
277295@pytest .mark .parametrize ("side_effect" , SIDE_EFFECTS )
@@ -318,7 +336,7 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
318336 attempt += 1
319337 current_effect = side_effect [attempt ]
320338 if isinstance (current_effect , ConnectionLostError ):
321- raise ConnectionLostError
339+ raise current_effect
322340 for frame in frames :
323341 yield frame
324342
@@ -339,7 +357,7 @@ async def test_maybe_write_frame_connection_already_lost() -> None:
339357
340358async def test_maybe_write_frame_connection_now_lost () -> None :
341359 class MockConnection (BaseMockConnection ):
342- write_frame = mock .AsyncMock (side_effect = [ConnectionLostError ])
360+ write_frame = mock .AsyncMock (side_effect = [build_dataclass ( ConnectionLostError ) ])
343361
344362 async with EnrichedConnectionManager (connection_class = MockConnection ) as manager :
345363 assert not await manager .maybe_write_frame (build_dataclass (ConnectFrame ))
0 commit comments