1515"""Execute Transactions Spec tests."""
1616from __future__ import annotations
1717
18+ import asyncio
1819import os
1920import sys
2021import time
2324sys .path [0 :0 ] = ["" ]
2425
2526from test .asynchronous import AsyncIntegrationTest , client_knobs , unittest
26- from test .asynchronous .utils_spec_runner import AsyncSpecTestCreator , SpecRunnerThread
27- from test .pymongo_mocks import DummyMonitor
27+ from test .asynchronous .pymongo_mocks import DummyMonitor
28+ from test .asynchronous . utils_spec_runner import AsyncSpecTestCreator , SpecRunnerTask
2829from test .utils import (
2930 CMAPListener ,
3031 async_client_context ,
@@ -91,23 +92,23 @@ class AsyncTestCMAP(AsyncIntegrationTest):
9192
9293 # Test operations:
9394
94- def start (self , op ):
95+ async def start (self , op ):
9596 """Run the 'start' thread operation."""
9697 target = op ["target" ]
97- thread = SpecRunnerThread (target )
98- thread .start ()
98+ thread = SpecRunnerTask (target )
99+ await thread .start ()
99100 self .targets [target ] = thread
100101
101- def wait (self , op ):
102+ async def wait (self , op ):
102103 """Run the 'wait' operation."""
103- time .sleep (op ["ms" ] / 1000.0 )
104+ await asyncio .sleep (op ["ms" ] / 1000.0 )
104105
105- def wait_for_thread (self , op ):
106+ async def wait_for_thread (self , op ):
106107 """Run the 'waitForThread' operation."""
107108 target = op ["target" ]
108109 thread = self .targets [target ]
109- thread .stop ()
110- thread .join ()
110+ await thread .stop ()
111+ await thread .join ()
111112 if thread .exc :
112113 raise thread .exc
113114 self .assertFalse (thread .ops )
@@ -123,53 +124,53 @@ async def wait_for_event(self, op):
123124 timeout = timeout ,
124125 )
125126
126- def check_out (self , op ):
127+ async def check_out (self , op ):
127128 """Run the 'checkOut' operation."""
128129 label = op ["label" ]
129- with self .pool .checkout () as conn :
130+ async with self .pool .checkout () as conn :
130131 # Call 'pin_cursor' so we can hold the socket.
131132 conn .pin_cursor ()
132133 if label :
133134 self .labels [label ] = conn
134135 else :
135136 self .addAsyncCleanup (conn .close_conn , None )
136137
137- def check_in (self , op ):
138+ async def check_in (self , op ):
138139 """Run the 'checkIn' operation."""
139140 label = op ["connection" ]
140141 conn = self .labels [label ]
141- self .pool .checkin (conn )
142+ await self .pool .checkin (conn )
142143
143- def ready (self , op ):
144+ async def ready (self , op ):
144145 """Run the 'ready' operation."""
145- self .pool .ready ()
146+ await self .pool .ready ()
146147
147- def clear (self , op ):
148+ async def clear (self , op ):
148149 """Run the 'clear' operation."""
149- if "interruptInUseAsyncConnections " in op :
150- self .pool .reset (interrupt_connections = op ["interruptInUseAsyncConnections " ])
150+ if "interruptInUseConnections " in op :
151+ await self .pool .reset (interrupt_connections = op ["interruptInUseConnections " ])
151152 else :
152- self .pool .reset ()
153+ await self .pool .reset ()
153154
154155 async def close (self , op ):
155- """Run the 'aclose ' operation."""
156- await self .pool .aclose ()
156+ """Run the 'close ' operation."""
157+ await self .pool .close ()
157158
158- def run_operation (self , op ):
159+ async def run_operation (self , op ):
159160 """Run a single operation in a test."""
160161 op_name = camel_to_snake (op ["name" ])
161162 thread = op ["thread" ]
162163 meth = getattr (self , op_name )
163164 if thread :
164- self .targets [thread ].schedule (lambda : meth (op ))
165+ await self .targets [thread ].schedule (lambda : meth (op ))
165166 else :
166- meth (op )
167+ await meth (op )
167168
168- def run_operations (self , ops ):
169+ async def run_operations (self , ops ):
169170 """Run a test's operations."""
170171 for op in ops :
171172 self ._ops .append (op )
172- self .run_operation (op )
173+ await self .run_operation (op )
173174
174175 def check_object (self , actual , expected ):
175176 """Assert that the actual object matches the expected object."""
@@ -215,10 +216,10 @@ async def _set_fail_point(self, client, command_args):
215216 cmd .update (command_args )
216217 await client .admin .command (cmd )
217218
218- def set_fail_point (self , command_args ):
219+ async def set_fail_point (self , command_args ):
219220 if not async_client_context .supports_failCommand_fail_point :
220221 self .skipTest ("failCommand fail point must be supported" )
221- self ._set_fail_point (self .client , command_args )
222+ await self ._set_fail_point (self .client , command_args )
222223
223224 async def run_scenario (self , scenario_def , test ):
224225 """Run a CMAP spec test."""
@@ -231,7 +232,7 @@ async def run_scenario(self, scenario_def, test):
231232 # Configure the fail point before creating the client.
232233 if "failPoint" in test :
233234 fp = test ["failPoint" ]
234- self .set_fail_point (fp )
235+ await self .set_fail_point (fp )
235236 self .addAsyncCleanup (
236237 self .set_fail_point , {"configureFailPoint" : fp ["configureFailPoint" ], "mode" : "off" }
237238 )
@@ -254,16 +255,18 @@ async def run_scenario(self, scenario_def, test):
254255 # PoolReadyEvents. Instead, update the initial state before
255256 # opening the Topology.
256257 td = async_client_context .client ._topology .description
257- sd = td .server_descriptions ()[(async_client_context .host , async_client_context .port )]
258+ sd = td .server_descriptions ()[
259+ (await async_client_context .host , await async_client_context .port )
260+ ]
258261 client ._topology ._description = updated_topology_description (
259262 client ._topology ._description , sd
260263 )
261264 # When backgroundThreadIntervalMS is negative we do not start the
262265 # background thread to ensure it never runs.
263266 if interval < 0 :
264- client ._topology .open ()
267+ await client ._topology .open ()
265268 else :
266- client ._get_topology ()
269+ await client ._get_topology ()
267270 self .pool = list (client ._topology ._servers .values ())[0 ].pool
268271
269272 # Map of target names to Thread objects.
@@ -273,21 +276,21 @@ async def run_scenario(self, scenario_def, test):
273276
274277 async def cleanup ():
275278 for t in self .targets .values ():
276- t .stop ()
279+ await t .stop ()
277280 for t in self .targets .values ():
278- t .join (5 )
281+ await t .join (5 )
279282 for conn in self .labels .values ():
280- await conn .aclose_conn (None )
283+ conn .close_conn (None )
281284
282285 self .addAsyncCleanup (cleanup )
283286
284287 try :
285288 if test ["error" ]:
286289 with self .assertRaises (PyMongoError ) as ctx :
287- self .run_operations (test ["operations" ])
290+ await self .run_operations (test ["operations" ])
288291 self .check_error (ctx .exception , test ["error" ])
289292 else :
290- self .run_operations (test ["operations" ])
293+ await self .run_operations (test ["operations" ])
291294
292295 self .check_events (test ["events" ], test ["ignore" ])
293296 except Exception :
@@ -452,8 +455,8 @@ async def test_close_leaves_pool_unpaused(self):
452455
453456
454457def create_test (scenario_def , test , name ):
455- def run_scenario (self ):
456- self .run_scenario (scenario_def , test )
458+ async def run_scenario (self ):
459+ await self .run_scenario (scenario_def , test )
457460
458461 return run_scenario
459462
@@ -468,9 +471,8 @@ async def tests(self, scenario_def):
468471 return [scenario_def ]
469472
470473
471- if _IS_SYNC :
472- test_creator = CMAPSpecTestCreator (create_test , AsyncTestCMAP , AsyncTestCMAP .TEST_PATH )
473- test_creator .create_tests ()
474+ test_creator = CMAPSpecTestCreator (create_test , AsyncTestCMAP , AsyncTestCMAP .TEST_PATH )
475+ test_creator .create_tests ()
474476
475477
476478if __name__ == "__main__" :
0 commit comments