2020import socketserver
2121import sys
2222import threading
23+ from asyncio import StreamReader
2324from pathlib import Path
2425
2526sys .path [0 :0 ] = ["" ]
2627
27- from test .asynchronous import AsyncIntegrationTest , AsyncPyMongoTestCase , unittest
28+ from test .asynchronous import AsyncIntegrationTest , AsyncPyMongoTestCase , AsyncUnitTest , unittest
2829from test .asynchronous .pymongo_mocks import DummyMonitor
2930from test .asynchronous .unified_format import generate_test_classes
3031from test .utils import (
@@ -226,7 +227,7 @@ async def run_scenario(self):
226227 return run_scenario
227228
228229
229- def create_tests ():
230+ async def create_tests ():
230231 for dirpath , _ , filenames in os .walk (SDAM_PATH ):
231232 dirname = os .path .split (dirpath )[- 1 ]
232233 # SDAM unified tests are handled separately.
@@ -247,7 +248,6 @@ def create_tests():
247248 setattr (TestAllScenarios , new_test .__name__ , new_test )
248249
249250
250- create_tests ()
251251
252252
253253class TestClusterTimeComparison (unittest .IsolatedAsyncioTestCase ):
@@ -277,45 +277,82 @@ async def send_cluster_time(time, inc, should_update):
277277
278278
279279class TestIgnoreStaleErrors (AsyncIntegrationTest ):
280- @async_client_context .require_sync
281- async def test_ignore_stale_connection_errors (self ):
282- N_THREADS = 5
283- barrier = threading .Barrier (N_THREADS , timeout = 30 )
284- client = await self .async_rs_or_single_client (minPoolSize = N_THREADS )
280+ if _IS_SYNC :
281+ async def test_ignore_stale_connection_errors (self ):
282+ N_THREADS = 5
283+ barrier = threading .Barrier (N_THREADS , timeout = 30 )
284+ client = await self .async_rs_or_single_client (minPoolSize = N_THREADS )
285+
286+ # Wait for initial discovery.
287+ await client .admin .command ("ping" )
288+ pool = await async_get_pool (client )
289+ starting_generation = pool .gen .get_overall ()
290+ await async_wait_until (lambda : len (pool .conns ) == N_THREADS , "created conns" )
291+
292+ def mock_command (* args , ** kwargs ):
293+ # Synchronize all threads to ensure they use the same generation.
294+ barrier .wait ()
295+ raise AutoReconnect ("mock AsyncConnection.command error" )
296+
297+ for conn in pool .conns :
298+ conn .command = mock_command
299+
300+ async def insert_command (i ):
301+ try :
302+ await client .test .command ("insert" , "test" , documents = [{"i" : i }])
303+ except AutoReconnect :
304+ pass
305+
306+ threads = []
307+ for i in range (N_THREADS ):
308+ threads .append (threading .Thread (target = insert_command , args = (i ,)))
309+ for t in threads :
310+ t .start ()
311+ for t in threads :
312+ t .join ()
313+
314+ # Expect a single pool reset for the network error
315+ self .assertEqual (starting_generation + 1 , pool .gen .get_overall ())
316+
317+ # Server should be selectable.
318+ await client .admin .command ("ping" )
319+ else :
320+ async def test_ignore_stale_connection_errors (self ):
321+ N_TASKS = 5
322+ barrier = asyncio .Barrier (N_TASKS )
323+ client = await self .async_rs_or_single_client (minPoolSize = N_TASKS )
285324
286- # Wait for initial discovery.
287- await client .admin .command ("ping" )
288- pool = await async_get_pool (client )
289- starting_generation = pool .gen .get_overall ()
290- await async_wait_until (lambda : len (pool .conns ) == N_THREADS , "created conns" )
291-
292- def mock_command (* args , ** kwargs ):
293- # Synchronize all threads to ensure they use the same generation.
294- barrier .wait ()
295- raise AutoReconnect ("mock AsyncConnection.command error" )
296-
297- for conn in pool .conns :
298- conn .command = mock_command
299-
300- async def insert_command (i ):
301- try :
302- await client .test .command ("insert" , "test" , documents = [{"i" : i }])
303- except AutoReconnect :
304- pass
305-
306- threads = []
307- for i in range (N_THREADS ):
308- threads .append (threading .Thread (target = insert_command , args = (i ,)))
309- for t in threads :
310- t .start ()
311- for t in threads :
312- t .join ()
313-
314- # Expect a single pool reset for the network error
315- self .assertEqual (starting_generation + 1 , pool .gen .get_overall ())
316-
317- # Server should be selectable.
318- await client .admin .command ("ping" )
325+ # Wait for initial discovery.
326+ await client .admin .command ("ping" )
327+ pool = await async_get_pool (client )
328+ starting_generation = pool .gen .get_overall ()
329+ await async_wait_until (lambda : len (pool .conns ) == N_TASKS , "created conns" )
330+
331+ async def mock_command (* args , ** kwargs ):
332+ # Synchronize all threads to ensure they use the same generation.
333+ await asyncio .wait_for (barrier .wait (), timeout = 30 )
334+ raise AutoReconnect ("mock AsyncConnection.command error" )
335+
336+ for conn in pool .conns :
337+ conn .command = mock_command
338+
339+ async def insert_command (i ):
340+ try :
341+ await client .test .command ("insert" , "test" , documents = [{"i" : i }])
342+ except AutoReconnect :
343+ pass
344+
345+ tasks = []
346+ for i in range (N_TASKS ):
347+ tasks .append (asyncio .create_task (insert_command (i )))
348+ for t in tasks :
349+ await t
350+
351+ # Expect a single pool reset for the network error
352+ self .assertEqual (starting_generation + 1 , pool .gen .get_overall ())
353+
354+ # Server should be selectable.
355+ await client .admin .command ("ping" )
319356
320357
321358class CMAPHeartbeatListener (HeartbeatEventListener , CMAPListener ):
@@ -432,30 +469,62 @@ def handle_request_and_shutdown(self):
432469
433470
434471class TestHeartbeatStartOrdering (AsyncPyMongoTestCase ):
435- @async_client_context .require_sync
436- async def test_heartbeat_start_ordering (self ):
437- events = []
438- listener = HeartbeatEventsListListener (events )
439- server = TCPServer (("localhost" , 9999 ), MockTCPHandler )
440- server .events = events
441- server_thread = threading .Thread (target = server .handle_request_and_shutdown )
442- server_thread .start ()
443- _c = await self .simple_client (
444- "mongodb://localhost:9999" , serverSelectionTimeoutMS = 500 , event_listeners = (listener ,)
445- )
446- server_thread .join ()
447- listener .wait_for_event (ServerHeartbeatStartedEvent , 1 )
448- listener .wait_for_event (ServerHeartbeatFailedEvent , 1 )
449-
450- self .assertEqual (
451- events ,
452- [
453- "serverHeartbeatStartedEvent" ,
454- "client connected" ,
455- "client hello received" ,
456- "serverHeartbeatFailedEvent" ,
457- ],
458- )
472+ if _IS_SYNC :
473+ async def test_heartbeat_start_ordering (self ):
474+ events = []
475+ listener = HeartbeatEventsListListener (events )
476+ server = TCPServer (("localhost" , 9999 ), MockTCPHandler )
477+ server .events = events
478+ server_thread = threading .Thread (target = server .handle_request_and_shutdown )
479+ server_thread .start ()
480+ _c = await self .simple_client (
481+ "mongodb://localhost:9999" , serverSelectionTimeoutMS = 500 , event_listeners = (listener ,)
482+ )
483+ server_thread .join ()
484+ listener .wait_for_event (ServerHeartbeatStartedEvent , 1 )
485+ listener .wait_for_event (ServerHeartbeatFailedEvent , 1 )
486+
487+ self .assertEqual (
488+ events ,
489+ [
490+ "serverHeartbeatStartedEvent" ,
491+ "client connected" ,
492+ "client hello received" ,
493+ "serverHeartbeatFailedEvent" ,
494+ ],
495+ )
496+ else :
497+ async def test_heartbeat_start_ordering (self ):
498+ events = []
499+
500+ async def handle_client (reader : StreamReader , writer ):
501+ server .events .append ("client connected" )
502+ print ("clent connected" )
503+ if (await reader .read (1024 )).strip ():
504+ server .events .append ("client hello received" )
505+ print ("client helllo recieved" )
506+ listener = HeartbeatEventsListListener (events )
507+ server = await asyncio .start_server (handle_client , "localhost" , 9999 )
508+ async with server :
509+ server .events = events
510+ _c = self .simple_client (
511+ "mongodb://localhost:9999" , serverSelectionTimeoutMS = 500 , event_listeners = (listener ,)
512+ )
513+ server .close ()
514+ server_task = asyncio .create_task (server .wait_closed ())
515+ await server_task
516+ await listener .async_wait_for_event (ServerHeartbeatStartedEvent , 1 )
517+ await listener .async_wait_for_event (ServerHeartbeatFailedEvent , 1 )
518+
519+ self .assertEqual (
520+ events ,
521+ [
522+ "serverHeartbeatStartedEvent" ,
523+ "client connected" ,
524+ "client hello received" ,
525+ "serverHeartbeatFailedEvent" ,
526+ ],
527+ )
459528
460529
461530# Generate unified tests.
0 commit comments