From 8118aea985f017457259bff78e64656232f08eb5 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 08:29:12 -0400 Subject: [PATCH 01/35] =?UTF-8?q?PYTHON-4844=20-=20Skip=20async=20test=5Fe?= =?UTF-8?q?ncryption.AsyncTestSpec.test=5Flegacy=5Fti=E2=80=A6=20(#1914)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/asynchronous/test_encryption.py | 5 +++++ test/test_encryption.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index c3f6223384..3e52fb9e1b 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -693,6 +693,11 @@ def maybe_skip_scenario(self, test): self.skipTest("PYTHON-3706 flaky test on Windows/macOS") if "type=symbol" in desc: self.skipTest("PyMongo does not support the symbol type") + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and not _IS_SYNC + ): + self.skipTest("PYTHON-4844 flaky test on async") def setup_scenario(self, scenario_def): """Override a test's setup.""" diff --git a/test/test_encryption.py b/test/test_encryption.py index 43c85e2c5b..64aa7ebf50 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -691,6 +691,11 @@ def maybe_skip_scenario(self, test): self.skipTest("PYTHON-3706 flaky test on Windows/macOS") if "type=symbol" in desc: self.skipTest("PyMongo does not support the symbol type") + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and not _IS_SYNC + ): + self.skipTest("PYTHON-4844 flaky test on async") def setup_scenario(self, scenario_def): """Override a test's setup.""" From 3a662291e010cbed832c00aff8ffe7b43d470489 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 10:48:24 -0400 Subject: [PATCH 02/35] PYTHON-4700 - Convert CSFLE tests to async (#1907) --- .evergreen/run-tests.sh | 4 +- pymongo/asynchronous/encryption.py | 12 +- pymongo/network_layer.py | 27 +- pymongo/synchronous/encryption.py | 12 +- test/__init__.py | 11 +- test/asynchronous/__init__.py | 11 +- test/asynchronous/test_encryption.py | 257 +++++++++--------- test/asynchronous/utils_spec_runner.py | 172 +++++++++++- .../spec/legacy/timeoutMS.json | 4 +- test/test_connection_monitoring.py | 3 +- test/test_encryption.py | 255 +++++++++-------- test/test_server_selection_in_window.py | 2 +- test/utils.py | 147 ---------- test/utils_spec_runner.py | 170 +++++++++++- tools/synchro.py | 2 + 15 files changed, 655 insertions(+), 434 deletions(-) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 8d7a9f082a..5e8429dd28 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html if [ -z "$TEST_SUITES" ]; then - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS else - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS fi else python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 9b00c13e10..735e543047 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -180,10 
+180,20 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = conn.recv(kms_context.bytes_needed) + if _IS_SYNC: + data = conn.recv(kms_context.bytes_needed) + else: + from pymongo.network_layer import ( # type: ignore[attr-defined] + async_receive_data_socket, + ) + + data = await async_receive_data_socket(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) + # Async raises an OSError instead of returning empty bytes + except OSError as err: + raise OSError("KMS connection closed") from err except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") from None finally: diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 4b57620d83..d14a21f41d 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None: loop.remove_writer(fd) async def _async_receive_ssl( - conn: _sslConn, length: int, loop: AbstractEventLoop + conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) total_read = 0 @@ -145,6 +145,9 @@ def _is_ready(fut: Future) -> None: read = conn.recv_into(mv[total_read:]) if read == 0: raise OSError("connection closed") + # KMS responses update their expected size after the first batch, stop reading after one loop + if once: + return mv[:read] total_read += read except BLOCKING_IO_ERRORS as exc: fd = conn.fileno() @@ -275,6 +278,28 @@ async def async_receive_data( sock.settimeout(sock_timeout) +async def async_receive_data_socket( + sock: Union[socket.socket, _sslConn], length: int +) -> memoryview: + sock_timeout = sock.gettimeout() + timeout = sock_timeout + + sock.settimeout(0.0) + loop = asyncio.get_event_loop() + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + return await asyncio.wait_for( + _async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] + timeout=timeout, + ) + else: + return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] + except asyncio.TimeoutError as err: + raise socket.timeout("timed out") from err + finally: + sock.settimeout(sock_timeout) + + async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: mv = memoryview(bytearray(length)) bytes_read = 0 diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index efef6df9e8..506ff8bcba 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -180,10 +180,20 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: while kms_context.bytes_needed > 0: # CSOT: update timeout. 
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = conn.recv(kms_context.bytes_needed) + if _IS_SYNC: + data = conn.recv(kms_context.bytes_needed) + else: + from pymongo.network_layer import ( # type: ignore[attr-defined] + receive_data_socket, + ) + + data = receive_data_socket(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) + # Async raises an OSError instead of returning empty bytes + except OSError as err: + raise OSError("KMS connection closed") from err except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") from None finally: diff --git a/test/__init__.py b/test/__init__.py index af12bc032a..fd33fde293 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -464,11 +464,12 @@ def wrap(*args, **kwargs): if not self.connected: pair = self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and condition(): - if wraps_async: - return f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if condition(): + if wraps_async: + return f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return f(*args, **kwargs) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 2a44785b2f..0579828c49 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -466,11 +466,12 @@ async def wrap(*args, **kwargs): if not self.connected: pair = await self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and await condition(): - if wraps_async: - return await f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if await condition(): + if wraps_async: + return await f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return await f(*args, **kwargs) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 3e52fb9e1b..88b005c4b3 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -30,6 +30,7 @@ import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -59,7 +60,6 @@ from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, async_wait_until, camel_to_snake_args, @@ -626,137 +626,132 @@ async def test_with_statement(self): KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(AsyncSpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() - - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not 
any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) - - opts = dict(opts) - return AutoEncryptionOpts(**opts) - - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - - return super().parse_client_options(opts) - - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") - - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and not _IS_SYNC - ): - self.skipTest("PYTHON-4844 flaky test on async") - - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[ - "datakeys" - ] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) - - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = async_client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. - coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. 
- if op["name"] == "updateOne": - errors += (ValueError,) - return errors - - def create_test(scenario_def, test, name): - @async_client_context.require_test_commands - def run_scenario(self): - self.run_scenario(scenario_def, test) - - return run_scenario - - test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) - test_creator.create_tests() - - if _HAVE_PYMONGOCRYPT: - globals().update( - generate_test_classes( - os.path.join(SPEC_PATH, "unified"), - module=__name__, - ) +class AsyncTestSpec(AsyncSpecRunner): + @classmethod + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + async def _setup_class(cls): + await super()._setup_class() + + def parse_auto_encrypt_opts(self, opts): + """Parse clientOptions.autoEncryptOpts.""" + opts = camel_to_snake_args(opts) + kms_providers = opts["kms_providers"] + if "aws" in kms_providers: + kms_providers["aws"] = AWS_CREDS + if not any(AWS_CREDS.values()): + self.skipTest("AWS environment credentials are not set") + if "awsTemporary" in kms_providers: + kms_providers["aws"] = AWS_TEMP_CREDS + del kms_providers["awsTemporary"] + if not any(AWS_TEMP_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "awsTemporaryNoSessionToken" in kms_providers: + kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS + del kms_providers["awsTemporaryNoSessionToken"] + if not any(AWS_TEMP_NO_SESSION_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "azure" in kms_providers: + kms_providers["azure"] = AZURE_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("Azure environment credentials are not set") + if "gcp" in kms_providers: + kms_providers["gcp"] = GCP_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("GCP environment credentials are not set") + if "kmip" in kms_providers: + kms_providers["kmip"] = KMIP_CREDS + opts["kms_tls_options"] = KMS_TLS_OPTS + if "key_vault_namespace" not in opts: + opts["key_vault_namespace"] = "keyvault.datakeys" + if "extra_options" in opts: + opts.update(camel_to_snake_args(opts.pop("extra_options"))) + + opts = dict(opts) + return AutoEncryptionOpts(**opts) + + def parse_client_options(self, opts): + """Override clientOptions parsing to support autoEncryptOpts.""" + encrypt_opts = opts.pop("autoEncryptOpts", None) + if encrypt_opts: + opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) + + return super().parse_client_options(opts) + + def get_object_name(self, op): + """Default object is collection.""" + return op.get("object", "collection") + + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + desc = test["description"].lower() + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and sys.platform in ("win32", "darwin") + ): + self.skipTest("PYTHON-3706 flaky test on Windows/macOS") + if "type=symbol" in desc: + self.skipTest("PyMongo does not support the symbol type") + if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC: + self.skipTest("PYTHON-4844 flaky test on async") + + async def setup_scenario(self, scenario_def): + """Override a test's setup.""" + key_vault_data = scenario_def["key_vault_data"] + encrypted_fields = scenario_def["encrypted_fields"] + json_schema = scenario_def["json_schema"] + data = scenario_def["data"] + coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] + await coll.delete_many({}) + if key_vault_data: + await 
coll.insert_many(key_vault_data) + + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = async_client_context.client.get_database(db_name, codec_options=OPTS) + await db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + await db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + await coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. + if op["name"] == "updateOne": + errors += (ValueError,) + return errors + + +async def create_test(scenario_def, test, name): + @async_client_context.require_test_commands + async def run_scenario(self): + await self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 12cb13c2cd..4d9c4c8f20 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, @@ -32,11 +37,12 @@ ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.asynchronous.grid_file import AsyncGridFSBucket from pymongo.asynchronous import client_session from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor @@ -83,6 +89,161 @@ def run(self): self.stop() +class AsyncSpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. 
+ """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. + """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = async_client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = async_client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = async_client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not async_client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + async def valid_topology(run_on_req): + return await async_client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return async_client_context.auth_enabled + return not async_client_context.auth_enabled + return True + + @staticmethod + def serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return async_client_context.serverless + elif serverless == "forbid": + return not async_client_context.serverless + else: # unset or "allow" + return True + + async def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + await self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + async def predicate(): + return await self.should_run_on(scenario_def) + + return async_client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + async def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. 
+ opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. + for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class AsyncSpecRunner(AsyncIntegrationTest): mongos_clients: List knobs: client_knobs @@ -284,7 +445,7 @@ async def run_operation(self, sessions, collection, operation): if object_name == "gridfsbucket": # Only create the GridFSBucket when we need it (for the gridfs # retryable reads tests). - obj = GridFSBucket(database, bucket_name=collection.name) + obj = AsyncGridFSBucket(database, bucket_name=collection.name) else: objects = { "client": database.client, @@ -312,7 +473,10 @@ async def run_operation(self, sessions, collection, operation): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = await cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addAsyncCleanup(result.close) @@ -588,7 +752,7 @@ async def run_scenario(self, scenario_def, test): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. 
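The key structural change in the utils_spec_runner diff above is that SpecTestCreator._create_tests becomes a coroutine (its runOn checks, such as valid_topology, now await async_client_context), while create_tests() still executes synchronously at import time by driving the coroutine with asyncio.run. A minimal, self-contained sketch of that pattern follows, using hypothetical toy names rather than the real runner's API:

    import asyncio

    _IS_SYNC = False  # the async build; tools/synchro.py flips this to True

    class ToySpecTestCreator:
        def __init__(self, test_class):
            self._test_class = test_class

        async def _create_tests(self):
            # The real runner walks JSON spec files here and may await
            # topology checks; this toy just attaches one generated test.
            def test_generated(self):
                assert True

            test_generated.__name__ = "test_generated"
            setattr(self._test_class, test_generated.__name__, test_generated)

        def create_tests(self):
            if _IS_SYNC:
                self._create_tests()  # plain method in the synchronous build
            else:
                asyncio.run(self._create_tests())

    class MyTests:
        pass

    ToySpecTestCreator(MyTests).create_tests()
    assert hasattr(MyTests, "test_generated")

Driving the coroutine to completion at import time keeps unittest discovery synchronous while letting the generation logic share awaitable context checks with the rest of the async suite.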
diff --git a/test/client-side-encryption/spec/legacy/timeoutMS.json b/test/client-side-encryption/spec/legacy/timeoutMS.json index b667767cfc..8411306224 100644 --- a/test/client-side-encryption/spec/legacy/timeoutMS.json +++ b/test/client-side-encryption/spec/legacy/timeoutMS.json @@ -110,7 +110,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 60 + "blockTimeMS": 600 } }, "clientOptions": { @@ -119,7 +119,7 @@ "aws": {} } }, - "timeoutMS": 50 + "timeoutMS": 500 }, "operations": [ { diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 142af0f9a7..d576a1184a 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -25,14 +25,13 @@ from test.pymongo_mocks import DummyMonitor from test.utils import ( CMAPListener, - SpecTestCreator, camel_to_snake, client_context, get_pool, get_pools, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread +from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator from bson.objectid import ObjectId from bson.son import SON diff --git a/test/test_encryption.py b/test/test_encryption.py index 64aa7ebf50..13a69ca9ad 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -30,6 +30,7 @@ import warnings from test import IntegrationTest, PyMongoTestCase, client_context from test.test_bulk import BulkTestBase +from test.utils_spec_runner import SpecRunner, SpecTestCreator from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -58,7 +59,6 @@ from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, camel_to_snake_args, is_greenthread_patched, @@ -624,135 +624,132 @@ def test_with_statement(self): KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() - - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) - - opts = dict(opts) - 
return AutoEncryptionOpts(**opts) - - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - - return super().parse_client_options(opts) - - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") - - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and not _IS_SYNC - ): - self.skipTest("PYTHON-4844 flaky test on async") - - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) - - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. - coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. 
- if op["name"] == "updateOne": - errors += (ValueError,) - return errors - - def create_test(scenario_def, test, name): - @client_context.require_test_commands - def run_scenario(self): - self.run_scenario(scenario_def, test) - - return run_scenario - - test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) - test_creator.create_tests() - - if _HAVE_PYMONGOCRYPT: - globals().update( - generate_test_classes( - os.path.join(SPEC_PATH, "unified"), - module=__name__, - ) +class TestSpec(SpecRunner): + @classmethod + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + def _setup_class(cls): + super()._setup_class() + + def parse_auto_encrypt_opts(self, opts): + """Parse clientOptions.autoEncryptOpts.""" + opts = camel_to_snake_args(opts) + kms_providers = opts["kms_providers"] + if "aws" in kms_providers: + kms_providers["aws"] = AWS_CREDS + if not any(AWS_CREDS.values()): + self.skipTest("AWS environment credentials are not set") + if "awsTemporary" in kms_providers: + kms_providers["aws"] = AWS_TEMP_CREDS + del kms_providers["awsTemporary"] + if not any(AWS_TEMP_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "awsTemporaryNoSessionToken" in kms_providers: + kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS + del kms_providers["awsTemporaryNoSessionToken"] + if not any(AWS_TEMP_NO_SESSION_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "azure" in kms_providers: + kms_providers["azure"] = AZURE_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("Azure environment credentials are not set") + if "gcp" in kms_providers: + kms_providers["gcp"] = GCP_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("GCP environment credentials are not set") + if "kmip" in kms_providers: + kms_providers["kmip"] = KMIP_CREDS + opts["kms_tls_options"] = KMS_TLS_OPTS + if "key_vault_namespace" not in opts: + opts["key_vault_namespace"] = "keyvault.datakeys" + if "extra_options" in opts: + opts.update(camel_to_snake_args(opts.pop("extra_options"))) + + opts = dict(opts) + return AutoEncryptionOpts(**opts) + + def parse_client_options(self, opts): + """Override clientOptions parsing to support autoEncryptOpts.""" + encrypt_opts = opts.pop("autoEncryptOpts", None) + if encrypt_opts: + opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) + + return super().parse_client_options(opts) + + def get_object_name(self, op): + """Default object is collection.""" + return op.get("object", "collection") + + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + desc = test["description"].lower() + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and sys.platform in ("win32", "darwin") + ): + self.skipTest("PYTHON-3706 flaky test on Windows/macOS") + if "type=symbol" in desc: + self.skipTest("PyMongo does not support the symbol type") + if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC: + self.skipTest("PYTHON-4844 flaky test on async") + + def setup_scenario(self, scenario_def): + """Override a test's setup.""" + key_vault_data = scenario_def["key_vault_data"] + encrypted_fields = scenario_def["encrypted_fields"] + json_schema = scenario_def["json_schema"] + data = scenario_def["data"] + coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] + coll.delete_many({}) + if key_vault_data: + coll.insert_many(key_vault_data) + + db_name = 
self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = client_context.client.get_database(db_name, codec_options=OPTS) + db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. + if op["name"] == "updateOne": + errors += (ValueError,) + return errors + + +def create_test(scenario_def, test, name): + @client_context.require_test_commands + def run_scenario(self): + self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 7cab42cca2..05772fa385 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -21,11 +21,11 @@ from test.utils import ( CMAPListener, OvertCommandListener, - SpecTestCreator, get_pool, wait_until, ) from test.utils_selection_tests import create_topology +from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node from pymongo.monitoring import ConnectionReadyEvent diff --git a/test/utils.py b/test/utils.py index 9c78cff3ad..4575a9fe10 100644 --- a/test/utils.py +++ b/test/utils.py @@ -418,153 +418,6 @@ def call_count(self): return len(self._call_list) -class SpecTestCreator: - """Class to create test cases from specifications.""" - - def __init__(self, create_test, test_class, test_path): - """Create a TestCreator object. - - :Parameters: - - `create_test`: callback that returns a test case. The callback - must accept the following arguments - a dictionary containing the - entire test specification (the `scenario_def`), a dictionary - containing the specification for which the test case will be - generated (the `test_def`). - - `test_class`: the unittest.TestCase class in which to create the - test case. - - `test_path`: path to the directory containing the JSON files with - the test specifications. - """ - self._create_test = create_test - self._test_class = test_class - self.test_path = test_path - - def _ensure_min_max_server_version(self, scenario_def, method): - """Test modifier that enforces a version range for the server on a - test case. 
- """ - if "minServerVersion" in scenario_def: - min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) - if min_ver is not None: - method = client_context.require_version_min(*min_ver)(method) - - if "maxServerVersion" in scenario_def: - max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) - if max_ver is not None: - method = client_context.require_version_max(*max_ver)(method) - - if "serverless" in scenario_def: - serverless = scenario_def["serverless"] - if serverless == "require": - serverless_satisfied = client_context.serverless - elif serverless == "forbid": - serverless_satisfied = not client_context.serverless - else: # unset or "allow" - serverless_satisfied = True - method = unittest.skipUnless( - serverless_satisfied, "Serverless requirement not satisfied" - )(method) - - return method - - @staticmethod - def valid_topology(run_on_req): - return client_context.is_topology_type( - run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) - ) - - @staticmethod - def min_server_version(run_on_req): - version = run_on_req.get("minServerVersion") - if version: - min_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version >= min_ver - return True - - @staticmethod - def max_server_version(run_on_req): - version = run_on_req.get("maxServerVersion") - if version: - max_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version <= max_ver - return True - - @staticmethod - def valid_auth_enabled(run_on_req): - if "authEnabled" in run_on_req: - if run_on_req["authEnabled"]: - return client_context.auth_enabled - return not client_context.auth_enabled - return True - - @staticmethod - def serverless_ok(run_on_req): - serverless = run_on_req["serverless"] - if serverless == "require": - return client_context.serverless - elif serverless == "forbid": - return not client_context.serverless - else: # unset or "allow" - return True - - def should_run_on(self, scenario_def): - run_on = scenario_def.get("runOn", []) - if not run_on: - # Always run these tests. - return True - - for req in run_on: - if ( - self.valid_topology(req) - and self.min_server_version(req) - and self.max_server_version(req) - and self.valid_auth_enabled(req) - and self.serverless_ok(req) - ): - return True - return False - - def ensure_run_on(self, scenario_def, method): - """Test modifier that enforces a 'runOn' on a test case.""" - return client_context._require( - lambda: self.should_run_on(scenario_def), "runOn not satisfied", method - ) - - def tests(self, scenario_def): - """Allow CMAP spec test to override the location of test.""" - return scenario_def["tests"] - - def create_tests(self): - for dirpath, _, filenames in os.walk(self.test_path): - dirname = os.path.split(dirpath)[-1] - - for filename in filenames: - with open(os.path.join(dirpath, filename)) as scenario_stream: - # Use tz_aware=False to match how CodecOptions decodes - # dates. - opts = json_util.JSONOptions(tz_aware=False) - scenario_def = ScenarioDict( - json_util.loads(scenario_stream.read(), json_options=opts) - ) - - test_type = os.path.splitext(filename)[0] - - # Construct test from scenario. 
- for test_def in self.tests(scenario_def): - test_name = "test_{}_{}_{}".format( - dirname, - test_type.replace("-", "_").replace(".", "_"), - str(test_def["description"].replace(" ", "_").replace(".", "_")), - ) - - new_test = self._create_test(scenario_def, test_def, test_name) - new_test = self._ensure_min_max_server_version(scenario_def, new_test) - new_test = self.ensure_run_on(scenario_def, new_test) - - new_test.__name__ = test_name - setattr(self._test_class, new_test.__name__, new_test) - - def ensure_all_connected(client: MongoClient) -> None: """Ensure that the client's connection pool has socket connections to all members of a replica set. Raises ConfigurationError when called with a diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 06a40351cd..8a061de0b1 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, @@ -32,11 +37,12 @@ ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference @@ -83,6 +89,161 @@ def run(self): self.stop() +class SpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. + """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. 
+ """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + def valid_topology(run_on_req): + return client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return client_context.auth_enabled + return not client_context.auth_enabled + return True + + @staticmethod + def serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return client_context.serverless + elif serverless == "forbid": + return not client_context.serverless + else: # unset or "allow" + return True + + def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + def predicate(): + return self.should_run_on(scenario_def) + + return client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. 
+ for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class SpecRunner(IntegrationTest): mongos_clients: List knobs: client_knobs @@ -312,7 +473,10 @@ def run_operation(self, sessions, collection, operation): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addCleanup(result.close) @@ -583,7 +747,7 @@ def run_scenario(self, scenario_def, test): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. diff --git a/tools/synchro.py b/tools/synchro.py index 0ec8985a05..f704919a17 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -105,6 +105,8 @@ "PyMongo|c|async": "PyMongo|c", "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", + "AsyncTestSpec": "TestSpec", + "AsyncSpecTestCreator": "SpecTestCreator", "async_set_fail_point": "set_fail_point", "async_ensure_all_connected": "ensure_all_connected", "async_repl_set_step_down": "repl_set_step_down", From 6973d2d2743b7679080b8be70391b767740cf674 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 11:02:06 -0400 Subject: [PATCH 03/35] PYTHON-4528 - Convert unified test runner to async (#1913) --- test/asynchronous/unified_format.py | 1573 +++++++++++++++++++++++++++ test/unified_format.py | 711 +----------- test/unified_format_shared.py | 679 ++++++++++++ tools/synchro.py | 1 + 4 files changed, 2301 insertions(+), 663 deletions(-) create mode 100644 test/asynchronous/unified_format.py create mode 100644 test/unified_format_shared.py diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py new file mode 100644 index 0000000000..4c37422951 --- /dev/null +++ b/test/asynchronous/unified_format.py @@ -0,0 +1,1573 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified test format runner. 
+ +https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst +""" +from __future__ import annotations + +import asyncio +import binascii +import copy +import functools +import os +import re +import sys +import time +import traceback +from asyncio import iscoroutinefunction +from collections import defaultdict +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + client_knobs, + unittest, +) +from test.unified_format_shared import ( + IS_INTERRUPTED, + KMS_TLS_OPTS, + PLACEHOLDER_MAP, + SKIP_CSOT_TESTS, + EventListenerUtil, + MatchEvaluatorUtil, + coerce_result, + parse_bulk_write_error_result, + parse_bulk_write_result, + parse_client_bulk_write_error_result, + parse_collection_or_database_options, + with_metaclass, +) +from test.utils import ( + async_get_pool, + camel_to_snake, + camel_to_snake_args, + parse_spec_options, + prepare_spec_arguments, + snake_to_camel, + wait_until, +) +from test.utils_spec_runner import SpecRunnerThread +from test.version import Version +from typing import Any, Dict, List, Mapping, Optional + +import pymongo +from bson import SON, json_util +from bson.codec_options import DEFAULT_CODEC_OPTIONS +from bson.objectid import ObjectId +from gridfs import AsyncGridFSBucket, GridOut +from pymongo import ASCENDING, AsyncMongoClient, CursorType, _csot +from pymongo.asynchronous.change_stream import AsyncChangeStream +from pymongo.asynchronous.client_session import AsyncClientSession, TransactionOptions, _TxnState +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.encryption import AsyncClientEncryption +from pymongo.asynchronous.helpers import anext +from pymongo.encryption_options import _HAVE_PYMONGOCRYPT +from pymongo.errors import ( + BulkWriteError, + ClientBulkWriteException, + ConfigurationError, + ConnectionFailure, + EncryptionError, + InvalidOperation, + NotPrimaryError, + OperationFailure, + PyMongoError, +) +from pymongo.monitoring import ( + CommandStartedEvent, +) +from pymongo.operations import ( + SearchIndexModel, +) +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import ServerApi +from pymongo.server_selectors import Selection, writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.topology_description import TopologyDescription +from pymongo.typings import _Address +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +async def is_run_on_requirement_satisfied(requirement): + topology_satisfied = True + req_topologies = requirement.get("topologies") + if req_topologies: + topology_satisfied = await async_client_context.is_topology_type(req_topologies) + + server_version = Version(*async_client_context.version[:3]) + + min_version_satisfied = True + req_min_server_version = requirement.get("minServerVersion") + if req_min_server_version: + min_version_satisfied = Version.from_string(req_min_server_version) <= server_version + + max_version_satisfied = True + req_max_server_version = requirement.get("maxServerVersion") + if req_max_server_version: + max_version_satisfied = Version.from_string(req_max_server_version) >= server_version + + serverless = requirement.get("serverless") + if serverless == "require": + serverless_satisfied = async_client_context.serverless + elif serverless == "forbid": 
+ serverless_satisfied = not async_client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + + params_satisfied = True + params = requirement.get("serverParameters") + if params: + for param, val in params.items(): + if param not in async_client_context.server_parameters: + params_satisfied = False + elif async_client_context.server_parameters[param] != val: + params_satisfied = False + + auth_satisfied = True + req_auth = requirement.get("auth") + if req_auth is not None: + if req_auth: + auth_satisfied = async_client_context.auth_enabled + if auth_satisfied and "authMechanism" in requirement: + auth_satisfied = async_client_context.check_auth_type(requirement["authMechanism"]) + else: + auth_satisfied = not async_client_context.auth_enabled + + csfle_satisfied = True + req_csfle = requirement.get("csfle") + if req_csfle is True: + min_version_satisfied = Version.from_string("4.2") <= server_version + csfle_satisfied = _HAVE_PYMONGOCRYPT and min_version_satisfied + + return ( + topology_satisfied + and min_version_satisfied + and max_version_satisfied + and serverless_satisfied + and params_satisfied + and auth_satisfied + and csfle_satisfied + ) + + +class NonLazyCursor: + """A find cursor proxy that creates the remote cursor when initialized.""" + + def __init__(self, find_cursor, client): + self.client = client + self.find_cursor = find_cursor + # Create the server side cursor. + self.first_result = None + + @classmethod + async def create(cls, find_cursor, client): + cursor = cls(find_cursor, client) + try: + cursor.first_result = await anext(cursor.find_cursor) + except StopAsyncIteration: + cursor.first_result = None + return cursor + + @property + def alive(self): + return self.first_result is not None or self.find_cursor.alive + + async def __anext__(self): + if self.first_result is not None: + first = self.first_result + self.first_result = None + return first + return await anext(self.find_cursor) + + # Added to support the iterateOnce operation. + try_next = __anext__ + + async def close(self): + await self.find_cursor.close() + self.client = None + + +class EntityMapUtil: + """Utility class that implements an entity map as per the unified + test format specification. 
+ """ + + def __init__(self, test_class): + self._entities: Dict[str, Any] = {} + self._listeners: Dict[str, EventListenerUtil] = {} + self._session_lsids: Dict[str, Mapping[str, Any]] = {} + self.test: UnifiedSpecTestMixinV1 = test_class + self._cluster_time: Mapping[str, Any] = {} + + def __contains__(self, item): + return item in self._entities + + def __len__(self): + return len(self._entities) + + def __getitem__(self, item): + try: + return self._entities[item] + except KeyError: + self.test.fail(f"Could not find entity named {item} in map") + + def __setitem__(self, key, value): + if not isinstance(key, str): + self.test.fail("Expected entity name of type str, got %s" % (type(key))) + + if key in self._entities: + self.test.fail(f"Entity named {key} already in map") + + self._entities[key] = value + + def _handle_placeholders(self, spec: dict, current: dict, path: str) -> Any: + if "$$placeholder" in current: + if path not in PLACEHOLDER_MAP: + raise ValueError(f"Could not find a placeholder value for {path}") + return PLACEHOLDER_MAP[path] + + for key in list(current): + value = current[key] + if isinstance(value, dict): + subpath = f"{path}/{key}" + current[key] = self._handle_placeholders(spec, value, subpath) + return current + + async def _create_entity(self, entity_spec, uri=None): + if len(entity_spec) != 1: + self.test.fail(f"Entity spec {entity_spec} did not contain exactly one top-level key") + + entity_type, spec = next(iter(entity_spec.items())) + spec = self._handle_placeholders(spec, spec, "") + if entity_type == "client": + kwargs: dict = {} + observe_events = spec.get("observeEvents", []) + + # The unified tests use topologyOpeningEvent, we use topologyOpenedEvent + for i in range(len(observe_events)): + if "topologyOpeningEvent" == observe_events[i]: + observe_events[i] = "topologyOpenedEvent" + ignore_commands = spec.get("ignoreCommandMonitoringEvents", []) + observe_sensitive_commands = spec.get("observeSensitiveCommands", False) + ignore_commands = [cmd.lower() for cmd in ignore_commands] + listener = EventListenerUtil( + observe_events, + ignore_commands, + observe_sensitive_commands, + spec.get("storeEventsAsEntities"), + self, + ) + self._listeners[spec["id"]] = listener + kwargs["event_listeners"] = [listener] + if spec.get("useMultipleMongoses"): + if async_client_context.load_balancer or async_client_context.serverless: + kwargs["h"] = async_client_context.MULTI_MONGOS_LB_URI + elif async_client_context.is_mongos: + kwargs["h"] = async_client_context.mongos_seeds() + kwargs.update(spec.get("uriOptions", {})) + server_api = spec.get("serverApi") + if "waitQueueSize" in kwargs: + raise unittest.SkipTest("PyMongo does not support waitQueueSize") + if "waitQueueMultiple" in kwargs: + raise unittest.SkipTest("PyMongo does not support waitQueueMultiple") + if server_api: + kwargs["server_api"] = ServerApi( + server_api["version"], + strict=server_api.get("strict"), + deprecation_errors=server_api.get("deprecationErrors"), + ) + if uri: + kwargs["h"] = uri + client = await self.test.async_rs_or_single_client(**kwargs) + self[spec["id"]] = client + self.test.addAsyncCleanup(client.close) + return + elif entity_type == "database": + client = self[spec["client"]] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + "Expected entity {} to be of type AsyncMongoClient, got {}".format( + spec["client"], type(client) + ) + ) + options = parse_collection_or_database_options(spec.get("databaseOptions", {})) + self[spec["id"]] = 
client.get_database(spec["databaseName"], **options) + return + elif entity_type == "collection": + database = self[spec["database"]] + if not isinstance(database, AsyncDatabase): + self.test.fail( + "Expected entity {} to be of type AsyncDatabase, got {}".format( + spec["database"], type(database) + ) + ) + options = parse_collection_or_database_options(spec.get("collectionOptions", {})) + self[spec["id"]] = database.get_collection(spec["collectionName"], **options) + return + elif entity_type == "session": + client = self[spec["client"]] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + "Expected entity {} to be of type AsyncMongoClient, got {}".format( + spec["client"], type(client) + ) + ) + opts = camel_to_snake_args(spec.get("sessionOptions", {})) + if "default_transaction_options" in opts: + txn_opts = parse_spec_options(opts["default_transaction_options"]) + txn_opts = TransactionOptions(**txn_opts) + opts = copy.deepcopy(opts) + opts["default_transaction_options"] = txn_opts + session = client.start_session(**dict(opts)) + self[spec["id"]] = session + self._session_lsids[spec["id"]] = copy.deepcopy(session.session_id) + self.test.addAsyncCleanup(session.end_session) + return + elif entity_type == "bucket": + db = self[spec["database"]] + kwargs = parse_spec_options(spec.get("bucketOptions", {}).copy()) + bucket = AsyncGridFSBucket(db, **kwargs) + + # PyMongo does not support AsyncGridFSBucket.drop(), emulate it. + @_csot.apply + async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None: + await self._files.drop(*args, **kwargs) + await self._chunks.drop(*args, **kwargs) + + if not hasattr(bucket, "drop"): + bucket.drop = drop.__get__(bucket) + self[spec["id"]] = bucket + return + elif entity_type == "clientEncryption": + opts = camel_to_snake_args(spec["clientEncryptionOpts"].copy()) + if isinstance(opts["key_vault_client"], str): + opts["key_vault_client"] = self[opts["key_vault_client"]] + # Set TLS options for providers like "kmip:name1". + kms_tls_options = {} + for provider in opts["kms_providers"]: + provider_type = provider.split(":")[0] + if provider_type in KMS_TLS_OPTS: + kms_tls_options[provider] = KMS_TLS_OPTS[provider_type] + self[spec["id"]] = AsyncClientEncryption( + opts["kms_providers"], + opts["key_vault_namespace"], + opts["key_vault_client"], + DEFAULT_CODEC_OPTIONS, + opts.get("kms_tls_options", kms_tls_options), + ) + return + elif entity_type == "thread": + name = spec["id"] + thread = SpecRunnerThread(name) + thread.start() + self[name] = thread + return + + self.test.fail(f"Unable to create entity of unknown type {entity_type}") + + async def create_entities_from_spec(self, entity_spec, uri=None): + for spec in entity_spec: + await self._create_entity(spec, uri=uri) + + def get_listener_for_client(self, client_name: str) -> EventListenerUtil: + client = self[client_name] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + f"Expected entity {client_name} to be of type AsyncMongoClient, got {type(client)}" + ) + + listener = self._listeners.get(client_name) + if not listener: + self.test.fail(f"No listeners configured for client {client_name}") + + return listener + + def get_lsid_for_session(self, session_name): + session = self[session_name] + if not isinstance(session, AsyncClientSession): + self.test.fail( + f"Expected entity {session_name} to be of type AsyncClientSession, got {type(session)}" + ) + + try: + return session.session_id + except InvalidOperation: + # session has been closed. 
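+            # Fall back to the lsid recorded when the session entity was created.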
+            return self._session_lsids[session_name]
+
+    async def advance_cluster_times(self) -> None:
+        """Manually synchronize entities when desired."""
+        if not self._cluster_time:
+            self._cluster_time = (await self.test.client.admin.command("ping")).get("$clusterTime")
+        for entity in self._entities.values():
+            if isinstance(entity, AsyncClientSession) and self._cluster_time:
+                entity.advance_cluster_time(self._cluster_time)
+
+
+class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
+    """Mixin class to run test cases from test specification files.
+
+    Assumes that tests conform to the `unified test format
+    <https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst>`_.
+
+    Specification of the test suite being currently run is available as
+    a class attribute ``TEST_SPEC``.
+    """
+
+    SCHEMA_VERSION = Version.from_string("1.21")
+    RUN_ON_LOAD_BALANCER = True
+    RUN_ON_SERVERLESS = True
+    TEST_SPEC: Any
+    mongos_clients: list[AsyncMongoClient] = []
+
+    @staticmethod
+    async def should_run_on(run_on_spec):
+        if not run_on_spec:
+            # Always run these tests.
+            return True
+
+        for req in run_on_spec:
+            if await is_run_on_requirement_satisfied(req):
+                return True
+        return False
+
+    async def insert_initial_data(self, initial_data):
+        for i, collection_data in enumerate(initial_data):
+            coll_name = collection_data["collectionName"]
+            db_name = collection_data["databaseName"]
+            opts = collection_data.get("createOptions", {})
+            documents = collection_data["documents"]
+
+            # Set up the collection with as few majority writes as possible.
+            db = self.client[db_name]
+            await db.drop_collection(coll_name)
+            # Only use majority wc on the final write.
+            if i == len(initial_data) - 1:
+                wc = WriteConcern(w="majority")
+            else:
+                wc = WriteConcern(w=1)
+            if documents:
+                if opts:
+                    await db.create_collection(coll_name, **opts)
+                await db.get_collection(coll_name, write_concern=wc).insert_many(documents)
+            else:
+                # Ensure collection exists
+                await db.create_collection(coll_name, write_concern=wc, **opts)
+
+    @classmethod
+    async def _setup_class(cls):
+        # super call creates internal client cls.client
+        await super()._setup_class()
+        # process file-level runOnRequirements
+        run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
+        if not await cls.should_run_on(run_on_spec):
+            raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
+
+        # add any special-casing for skipping tests here
+        if async_client_context.storage_engine == "mmapv1":
+            if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
+                cls.TEST_PATH
+            ):
+                raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
+
+        # Handle mongos_clients for transactions tests.
+        cls.mongos_clients = []
+        if (
+            async_client_context.supports_transactions()
+            and not async_client_context.load_balancer
+            and not async_client_context.serverless
+        ):
+            for address in async_client_context.mongoses:
+                cls.mongos_clients.append(
+                    await cls.unmanaged_async_single_client("{}:{}".format(*address))
+                )
+
+        # Speed up the tests by decreasing the heartbeat frequency.
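+        # client_knobs lowers PyMongo's internal monitoring intervals so that
+        # topology changes and cursor cleanup surface quickly during tests.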
+ cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + async def _tearDown_class(cls): + cls.knobs.disable() + for client in cls.mongos_clients: + await client.close() + await super()._tearDown_class() + + async def asyncSetUp(self): + await super().asyncSetUp() + # process schemaVersion + # note: we check major schema version during class generation + # note: we do this here because we cannot run assertions in setUpClass + version = Version.from_string(self.TEST_SPEC["schemaVersion"]) + self.assertLessEqual( + version, + self.SCHEMA_VERSION, + f"expected schema version {self.SCHEMA_VERSION} or lower, got {version}", + ) + + # initialize internals + self.match_evaluator = MatchEvaluatorUtil(self) + + def maybe_skip_test(self, spec): + # add any special-casing for skipping tests here + if async_client_context.storage_engine == "mmapv1": + if ( + "Dirty explicit session is discarded" in spec["description"] + or "Dirty implicit session is discarded" in spec["description"] + or "Cancel server check" in spec["description"] + ): + self.skipTest("MMAPv1 does not support retryWrites=True") + if ( + "AsyncDatabase-level aggregate with $out includes read preference for 5.0+ server" + in spec["description"] + ): + if async_client_context.version[0] == 8: + self.skipTest("waiting on PYTHON-4356") + if "Aggregate with $out includes read preference for 5.0+ server" in spec["description"]: + if async_client_context.version[0] == 8: + self.skipTest("waiting on PYTHON-4356") + if "Client side error in command starting transaction" in spec["description"]: + self.skipTest("Implement PYTHON-1894") + if "timeoutMS applied to entire download" in spec["description"]: + self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime") + + class_name = self.__class__.__name__.lower() + description = spec["description"].lower() + if "csot" in class_name: + if "gridfs" in class_name and sys.platform == "win32": + self.skipTest("PYTHON-3522 CSOT GridFS tests are flaky on Windows") + if async_client_context.storage_engine == "mmapv1": + self.skipTest( + "MMAPv1 does not support retryable writes which is required for CSOT tests" + ) + if "change" in description or "change" in class_name: + self.skipTest("CSOT not implemented for watch()") + if "cursors" in class_name: + self.skipTest("CSOT not implemented for cursors") + if "tailable" in class_name: + self.skipTest("CSOT not implemented for tailable cursors") + if "sessions" in class_name: + self.skipTest("CSOT not implemented for sessions") + if "withtransaction" in description: + self.skipTest("CSOT not implemented for with_transaction") + if "transaction" in class_name or "transaction" in description: + self.skipTest("CSOT not implemented for transactions") + + # Some tests need to be skipped based on the operations they try to run. 
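+        # Checking up front means an unsupported operation skips the whole
+        # test rather than failing it partway through.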
+ for op in spec["operations"]: + name = op["name"] + if name == "count": + self.skipTest("PyMongo does not support count()") + if name == "listIndexNames": + self.skipTest("PyMongo does not support list_index_names()") + if async_client_context.storage_engine == "mmapv1": + if name == "createChangeStream": + self.skipTest("MMAPv1 does not support change streams") + if name == "withTransaction" or name == "startTransaction": + self.skipTest("MMAPv1 does not support document-level locking") + if not async_client_context.test_commands_enabled: + if name == "failPoint" or name == "targetedFailPoint": + self.skipTest("Test commands must be enabled to use fail points") + if name == "modifyCollection": + self.skipTest("PyMongo does not support modifyCollection") + if "timeoutMode" in op.get("arguments", {}): + self.skipTest("PyMongo does not support timeoutMode") + + def process_error(self, exception, spec): + if isinstance(exception, unittest.SkipTest): + raise + is_error = spec.get("isError") + is_client_error = spec.get("isClientError") + is_timeout_error = spec.get("isTimeoutError") + error_contains = spec.get("errorContains") + error_code = spec.get("errorCode") + error_code_name = spec.get("errorCodeName") + error_labels_contain = spec.get("errorLabelsContain") + error_labels_omit = spec.get("errorLabelsOmit") + expect_result = spec.get("expectResult") + error_response = spec.get("errorResponse") + if error_response: + if isinstance(exception, ClientBulkWriteException): + self.match_evaluator.match_result(error_response, exception.error.details) + else: + self.match_evaluator.match_result(error_response, exception.details) + + if is_error: + # already satisfied because exception was raised + pass + + if is_client_error: + if isinstance(exception, ClientBulkWriteException): + error = exception.error + else: + error = exception + # Connection errors are considered client errors. + if isinstance(error, ConnectionFailure): + self.assertNotIsInstance(error, NotPrimaryError) + elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)): + pass + else: + self.assertNotIsInstance(error, PyMongoError) + + if is_timeout_error: + self.assertIsInstance(exception, PyMongoError) + if not exception.timeout: + # Re-raise the exception for better diagnostics. 
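+                # (isTimeoutError expects the exception's timeout flag to be True)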
+                raise exception
+
+        if error_contains:
+            if isinstance(exception, BulkWriteError):
+                errmsg = str(exception.details).lower()
+            elif isinstance(exception, ClientBulkWriteException):
+                errmsg = str(exception.details).lower()
+            else:
+                errmsg = str(exception).lower()
+            self.assertIn(error_contains.lower(), errmsg)
+
+        if error_code:
+            if isinstance(exception, ClientBulkWriteException):
+                self.assertEqual(error_code, exception.error.details.get("code"))
+            else:
+                self.assertEqual(error_code, exception.details.get("code"))
+
+        if error_code_name:
+            if isinstance(exception, ClientBulkWriteException):
+                self.assertEqual(error_code_name, exception.error.details.get("codeName"))
+            else:
+                self.assertEqual(error_code_name, exception.details.get("codeName"))
+
+        if error_labels_contain:
+            if isinstance(exception, ClientBulkWriteException):
+                error = exception.error
+            else:
+                error = exception
+            labels = [
+                err_label for err_label in error_labels_contain if error.has_error_label(err_label)
+            ]
+            self.assertEqual(labels, error_labels_contain)
+
+        if error_labels_omit:
+            for err_label in error_labels_omit:
+                if exception.has_error_label(err_label):
+                    self.fail(f"Exception '{exception}' unexpectedly had label '{err_label}'")
+
+        if expect_result:
+            if isinstance(exception, BulkWriteError):
+                result = parse_bulk_write_error_result(exception)
+                self.match_evaluator.match_result(expect_result, result)
+            elif isinstance(exception, ClientBulkWriteException):
+                result = parse_client_bulk_write_error_result(exception)
+                self.match_evaluator.match_result(expect_result, result)
+            else:
+                self.fail(
+                    f"expectResult can only be specified with {BulkWriteError} or {ClientBulkWriteException} exceptions"
+                )
+
+        return exception
+
+    def __raise_if_unsupported(self, opname, target, *target_types):
+        if not isinstance(target, target_types):
+            self.fail(f"Operation {opname} not supported for entity of type {type(target)}")
+
+    async def __entityOperation_createChangeStream(self, target, *args, **kwargs):
+        if async_client_context.storage_engine == "mmapv1":
+            self.skipTest("MMAPv1 does not support change streams")
+        self.__raise_if_unsupported(
+            "createChangeStream", target, AsyncMongoClient, AsyncDatabase, AsyncCollection
+        )
+        stream = await target.watch(*args, **kwargs)
+        self.addAsyncCleanup(stream.close)
+        return stream
+
+    async def _clientOperation_createChangeStream(self, target, *args, **kwargs):
+        return await self.__entityOperation_createChangeStream(target, *args, **kwargs)
+
+    async def _databaseOperation_createChangeStream(self, target, *args, **kwargs):
+        return await self.__entityOperation_createChangeStream(target, *args, **kwargs)
+
+    async def _collectionOperation_createChangeStream(self, target, *args, **kwargs):
+        return await self.__entityOperation_createChangeStream(target, *args, **kwargs)
+
+    async def _databaseOperation_runCommand(self, target, **kwargs):
+        self.__raise_if_unsupported("runCommand", target, AsyncDatabase)
+        # Ensure the first key is the command name.
+        ordered_command = SON([(kwargs.pop("command_name"), 1)])
+        ordered_command.update(kwargs["command"])
+        kwargs["command"] = ordered_command
+        return await target.command(**kwargs)
+
+    async def _databaseOperation_runCursorCommand(self, target, **kwargs):
+        cursor = await self._databaseOperation_createCommandCursor(target, **kwargs)
+        return await cursor.to_list()
+
+    async def _databaseOperation_createCommandCursor(self, target, **kwargs):
+        self.__raise_if_unsupported("createCommandCursor", target, AsyncDatabase)
+        # Ensure the first key is the command name.
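+        # SON is an ordered mapping, so the command name stays the first key
+        # when the document is encoded to BSON, as the server requires.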
+        ordered_command = SON([(kwargs.pop("command_name"), 1)])
+        ordered_command.update(kwargs["command"])
+        kwargs["command"] = ordered_command
+        batch_size = 0
+
+        cursor_type = kwargs.pop("cursor_type", "nonTailable")
+        if cursor_type == CursorType.TAILABLE:
+            ordered_command["tailable"] = True
+        elif cursor_type == CursorType.TAILABLE_AWAIT:
+            ordered_command["tailable"] = True
+            ordered_command["awaitData"] = True
+        elif cursor_type != "nonTailable":
+            self.fail(f"unknown cursorType: {cursor_type}")
+
+        if "maxTimeMS" in kwargs:
+            kwargs["max_await_time_ms"] = kwargs.pop("maxTimeMS")
+
+        if "batch_size" in kwargs:
+            batch_size = kwargs.pop("batch_size")
+
+        cursor = await target.cursor_command(**kwargs)
+
+        if batch_size > 0:
+            cursor.batch_size(batch_size)
+
+        return cursor
+
+    async def kill_all_sessions(self):
+        if getattr(self, "client", None) is None:
+            return
+        clients = self.mongos_clients if self.mongos_clients else [self.client]
+        for client in clients:
+            try:
+                await client.admin.command("killAllSessions", [])
+            except OperationFailure:
+                # "operation was interrupted" by killing the command's
+                # own session.
+                pass
+
+    async def _databaseOperation_listCollections(self, target, *args, **kwargs):
+        if "batch_size" in kwargs:
+            kwargs["cursor"] = {"batchSize": kwargs.pop("batch_size")}
+        cursor = await target.list_collections(*args, **kwargs)
+        return await cursor.to_list()
+
+    async def _databaseOperation_createCollection(self, target, *args, **kwargs):
+        # PYTHON-1936 Ignore the listCollections event from create_collection.
+        kwargs["check_exists"] = False
+        ret = await target.create_collection(*args, **kwargs)
+        return ret
+
+    async def __entityOperation_aggregate(self, target, *args, **kwargs):
+        self.__raise_if_unsupported("aggregate", target, AsyncDatabase, AsyncCollection)
+        return await (await target.aggregate(*args, **kwargs)).to_list()
+
+    async def _databaseOperation_aggregate(self, target, *args, **kwargs):
+        return await self.__entityOperation_aggregate(target, *args, **kwargs)
+
+    async def _collectionOperation_aggregate(self, target, *args, **kwargs):
+        return await self.__entityOperation_aggregate(target, *args, **kwargs)
+
+    async def _collectionOperation_find(self, target, *args, **kwargs):
+        self.__raise_if_unsupported("find", target, AsyncCollection)
+        find_cursor = target.find(*args, **kwargs)
+        return await find_cursor.to_list()
+
+    async def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
+        self.__raise_if_unsupported("find", target, AsyncCollection)
+        if "filter" not in kwargs:
+            self.fail('createFindCursor requires a "filter" argument')
+        cursor = await NonLazyCursor.create(target.find(*args, **kwargs), target.database.client)
+        self.addAsyncCleanup(cursor.close)
+        return cursor
+
+    def _collectionOperation_count(self, target, *args, **kwargs):
+        self.skipTest("PyMongo does not support collection.count()")
+
+    async def _collectionOperation_listIndexes(self, target, *args, **kwargs):
+        if "batch_size" in kwargs:
+            self.skipTest("PyMongo does not support batch_size for list_indexes")
+        return await (await target.list_indexes(*args, **kwargs)).to_list()
+
+    def _collectionOperation_listIndexNames(self, target, *args, **kwargs):
+        self.skipTest("PyMongo does not support list_index_names")
+
+    async def _collectionOperation_createSearchIndexes(self, target, *args, **kwargs):
+        models = [SearchIndexModel(**i) for i in kwargs["models"]]
+        return await target.create_search_indexes(models)
+
+    async def _collectionOperation_listSearchIndexes(self,
target, *args, **kwargs): + name = kwargs.get("name") + agg_kwargs = kwargs.get("aggregation_options", dict()) + return await (await target.list_search_indexes(name, **agg_kwargs)).to_list() + + async def _sessionOperation_withTransaction(self, target, *args, **kwargs): + if async_client_context.storage_engine == "mmapv1": + self.skipTest("MMAPv1 does not support document-level locking") + self.__raise_if_unsupported("withTransaction", target, AsyncClientSession) + return await target.with_transaction(*args, **kwargs) + + async def _sessionOperation_startTransaction(self, target, *args, **kwargs): + if async_client_context.storage_engine == "mmapv1": + self.skipTest("MMAPv1 does not support document-level locking") + self.__raise_if_unsupported("startTransaction", target, AsyncClientSession) + return await target.start_transaction(*args, **kwargs) + + async def _changeStreamOperation_iterateUntilDocumentOrError(self, target, *args, **kwargs): + self.__raise_if_unsupported("iterateUntilDocumentOrError", target, AsyncChangeStream) + return await anext(target) + + async def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs): + self.__raise_if_unsupported( + "iterateUntilDocumentOrError", target, NonLazyCursor, AsyncCommandCursor + ) + while target.alive: + try: + return await anext(target) + except StopAsyncIteration: + pass + return None + + async def _cursor_close(self, target, *args, **kwargs): + self.__raise_if_unsupported("close", target, NonLazyCursor, AsyncCommandCursor) + return await target.close() + + async def _clientEncryptionOperation_createDataKey(self, target, *args, **kwargs): + if "opts" in kwargs: + kwargs.update(camel_to_snake_args(kwargs.pop("opts"))) + + return await target.create_data_key(*args, **kwargs) + + async def _clientEncryptionOperation_getKeys(self, target, *args, **kwargs): + return await (await target.get_keys(*args, **kwargs)).to_list() + + async def _clientEncryptionOperation_deleteKey(self, target, *args, **kwargs): + result = await target.delete_key(*args, **kwargs) + response = result.raw_result + response["deletedCount"] = result.deleted_count + return response + + async def _clientEncryptionOperation_rewrapManyDataKey(self, target, *args, **kwargs): + if "opts" in kwargs: + kwargs.update(camel_to_snake_args(kwargs.pop("opts"))) + data = await target.rewrap_many_data_key(*args, **kwargs) + if data.bulk_write_result: + return {"bulkWriteResult": parse_bulk_write_result(data.bulk_write_result)} + return {} + + async def _clientEncryptionOperation_encrypt(self, target, *args, **kwargs): + if "opts" in kwargs: + kwargs.update(camel_to_snake_args(kwargs.pop("opts"))) + return await target.encrypt(*args, **kwargs) + + async def _bucketOperation_download( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> bytes: + async with await target.open_download_stream(*args, **kwargs) as gout: + return await gout.read() + + async def _bucketOperation_downloadByName( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> bytes: + async with await target.open_download_stream_by_name(*args, **kwargs) as gout: + return await gout.read() + + async def _bucketOperation_upload( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> ObjectId: + kwargs["source"] = binascii.unhexlify(kwargs.pop("source")["$$hexBytes"]) + if "content_type" in kwargs: + kwargs.setdefault("metadata", {})["contentType"] = kwargs.pop("content_type") + return await target.upload_from_stream(*args, **kwargs) + + async def 
_bucketOperation_uploadWithId( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> Any: + kwargs["source"] = binascii.unhexlify(kwargs.pop("source")["$$hexBytes"]) + if "content_type" in kwargs: + kwargs.setdefault("metadata", {})["contentType"] = kwargs.pop("content_type") + return await target.upload_from_stream_with_id(*args, **kwargs) + + async def _bucketOperation_find( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> List[GridOut]: + return await target.find(*args, **kwargs).to_list() + + async def run_entity_operation(self, spec): + target = self.entity_map[spec["object"]] + opname = spec["name"] + opargs = spec.get("arguments") + expect_error = spec.get("expectError") + save_as_entity = spec.get("saveResultAsEntity") + expect_result = spec.get("expectResult") + ignore = spec.get("ignoreResultAndError") + if ignore and (expect_error or save_as_entity or expect_result): + raise ValueError( + "ignoreResultAndError is incompatible with saveResultAsEntity" + ", expectError, and expectResult" + ) + if opargs: + arguments = parse_spec_options(copy.deepcopy(opargs)) + prepare_spec_arguments( + spec, + arguments, + camel_to_snake(opname), + self.entity_map, + self.run_operations_and_throw, + ) + else: + arguments = {} + + if isinstance(target, AsyncMongoClient): + method_name = f"_clientOperation_{opname}" + elif isinstance(target, AsyncDatabase): + method_name = f"_databaseOperation_{opname}" + elif isinstance(target, AsyncCollection): + method_name = f"_collectionOperation_{opname}" + # contentType is always stored in metadata in pymongo. + if target.name.endswith(".files") and opname == "find": + for doc in spec.get("expectResult", []): + if "contentType" in doc: + doc.setdefault("metadata", {})["contentType"] = doc.pop("contentType") + elif isinstance(target, AsyncChangeStream): + method_name = f"_changeStreamOperation_{opname}" + elif isinstance(target, (NonLazyCursor, AsyncCommandCursor)): + method_name = f"_cursor_{opname}" + elif isinstance(target, AsyncClientSession): + method_name = f"_sessionOperation_{opname}" + elif isinstance(target, AsyncGridFSBucket): + method_name = f"_bucketOperation_{opname}" + if "id" in arguments: + arguments["file_id"] = arguments.pop("id") + # MD5 is always disabled in pymongo. + arguments.pop("disable_md5", None) + elif isinstance(target, AsyncClientEncryption): + method_name = f"_clientEncryptionOperation_{opname}" + else: + method_name = "doesNotExist" + + try: + method = getattr(self, method_name) + except AttributeError: + target_opname = camel_to_snake(opname) + if target_opname == "iterate_once": + target_opname = "try_next" + if target_opname == "client_bulk_write": + target_opname = "bulk_write" + try: + cmd = getattr(target, target_opname) + except AttributeError: + self.fail(f"Unsupported operation {opname} on entity {target}") + else: + cmd = functools.partial(method, target) + + try: + # CSOT: Translate the spec test "timeout" arg into pymongo's context timeout API. + if "timeout" in arguments: + timeout = arguments.pop("timeout") + with pymongo.timeout(timeout): + result = await cmd(**dict(arguments)) + else: + result = await cmd(**dict(arguments)) + except Exception as exc: + # Ignore all operation errors but to avoid masking bugs don't + # ignore things like TypeError and ValueError. 
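+            # Only PyMongoError and its subclasses are swallowed here; anything
+            # else (TypeError, ValueError, ...) propagates to fail the test.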
+            if ignore and isinstance(exc, (PyMongoError,)):
+                return exc
+            if expect_error:
+                if method_name == "_collectionOperation_bulkWrite":
+                    self.skipTest("Skipping test pending PYTHON-4598")
+                return self.process_error(exc, expect_error)
+            raise
+        else:
+            if method_name == "_collectionOperation_bulkWrite":
+                self.skipTest("Skipping test pending PYTHON-4598")
+            if expect_error:
+                self.fail(f'Expected error {expect_error} but "{opname}" succeeded: {result}')
+
+        if expect_result:
+            actual = coerce_result(opname, result)
+            self.match_evaluator.match_result(expect_result, actual)
+
+        if save_as_entity:
+            self.entity_map[save_as_entity] = result
+            return None
+        return None
+
+    async def __set_fail_point(self, client, command_args):
+        if not async_client_context.test_commands_enabled:
+            self.skipTest("Test commands must be enabled")
+
+        cmd_on = SON([("configureFailPoint", "failCommand")])
+        cmd_on.update(command_args)
+        await client.admin.command(cmd_on)
+        self.addAsyncCleanup(
+            client.admin.command, "configureFailPoint", cmd_on["configureFailPoint"], mode="off"
+        )
+
+    async def _testOperation_failPoint(self, spec):
+        await self.__set_fail_point(
+            client=self.entity_map[spec["client"]], command_args=spec["failPoint"]
+        )
+
+    async def _testOperation_targetedFailPoint(self, spec):
+        session = self.entity_map[spec["session"]]
+        if not session._pinned_address:
+            self.fail(
+                f"Cannot use targetedFailPoint operation with unpinned session {spec['session']}"
+            )
+
+        client = await self.async_single_client("{}:{}".format(*session._pinned_address))
+        self.addAsyncCleanup(client.close)
+        await self.__set_fail_point(client=client, command_args=spec["failPoint"])
+
+    async def _testOperation_createEntities(self, spec):
+        await self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
+        await self.entity_map.advance_cluster_times()
+
+    def _testOperation_assertSessionTransactionState(self, spec):
+        session = self.entity_map[spec["session"]]
+        expected_state = getattr(_TxnState, spec["state"].upper())
+        self.assertEqual(expected_state, session._transaction.state)
+
+    def _testOperation_assertSessionPinned(self, spec):
+        session = self.entity_map[spec["session"]]
+        self.assertIsNotNone(session._transaction.pinned_address)
+
+    def _testOperation_assertSessionUnpinned(self, spec):
+        session = self.entity_map[spec["session"]]
+        self.assertIsNone(session._pinned_address)
+        self.assertIsNone(session._transaction.pinned_address)
+
+    def __get_last_two_command_lsids(self, listener):
+        cmd_started_events = []
+        for event in reversed(listener.events):
+            if isinstance(event, CommandStartedEvent):
+                cmd_started_events.append(event)
+        if len(cmd_started_events) < 2:
+            self.fail(
+                "Needed 2 CommandStartedEvents to compare lsids, "
+                "got %s" % (len(cmd_started_events))
+            )
+        return tuple([e.command["lsid"] for e in cmd_started_events][:2])
+
+    def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec):
+        listener = self.entity_map.get_listener_for_client(spec["client"])
+        self.assertNotEqual(*self.__get_last_two_command_lsids(listener))
+
+    def _testOperation_assertSameLsidOnLastTwoCommands(self, spec):
+        listener = self.entity_map.get_listener_for_client(spec["client"])
+        self.assertEqual(*self.__get_last_two_command_lsids(listener))
+
+    def _testOperation_assertSessionDirty(self, spec):
+        session = self.entity_map[spec["session"]]
+        self.assertTrue(session._server_session.dirty)
+
+    def _testOperation_assertSessionNotDirty(self, spec):
+        session =
self.entity_map[spec["session"]] + return self.assertFalse(session._server_session.dirty) + + async def _testOperation_assertCollectionExists(self, spec): + database_name = spec["databaseName"] + collection_name = spec["collectionName"] + collection_name_list = list( + await self.client.get_database(database_name).list_collection_names() + ) + self.assertIn(collection_name, collection_name_list) + + async def _testOperation_assertCollectionNotExists(self, spec): + database_name = spec["databaseName"] + collection_name = spec["collectionName"] + collection_name_list = list( + await self.client.get_database(database_name).list_collection_names() + ) + self.assertNotIn(collection_name, collection_name_list) + + async def _testOperation_assertIndexExists(self, spec): + collection = self.client[spec["databaseName"]][spec["collectionName"]] + index_names = [idx["name"] async for idx in await collection.list_indexes()] + self.assertIn(spec["indexName"], index_names) + + async def _testOperation_assertIndexNotExists(self, spec): + collection = self.client[spec["databaseName"]][spec["collectionName"]] + async for index in await collection.list_indexes(): + self.assertNotEqual(spec["indexName"], index["name"]) + + async def _testOperation_assertNumberConnectionsCheckedOut(self, spec): + client = self.entity_map[spec["client"]] + pool = await async_get_pool(client) + self.assertEqual(spec["connections"], pool.active_sockets) + + def _event_count(self, client_name, event): + listener = self.entity_map.get_listener_for_client(client_name) + actual_events = listener.get_events("all") + count = 0 + for actual in actual_events: + try: + self.match_evaluator.match_event(event, actual) + except AssertionError: + continue + else: + count += 1 + return count + + def _testOperation_assertEventCount(self, spec): + """Run the assertEventCount test operation. + + Assert the given event was published exactly `count` times. + """ + client, event, count = spec["client"], spec["event"], spec["count"] + self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}") + + def _testOperation_waitForEvent(self, spec): + """Run the waitForEvent test operation. + + Wait for a number of events to be published, or fail. 
+ """ + client, event, count = spec["client"], spec["event"], spec["count"] + wait_until( + lambda: self._event_count(client, event) >= count, + f"find {count} {event} event(s)", + ) + + async def _testOperation_wait(self, spec): + """Run the "wait" test operation.""" + await asyncio.sleep(spec["ms"] / 1000.0) + + def _testOperation_recordTopologyDescription(self, spec): + """Run the recordTopologyDescription test operation.""" + self.entity_map[spec["id"]] = self.entity_map[spec["client"]].topology_description + + def _testOperation_assertTopologyType(self, spec): + """Run the assertTopologyType test operation.""" + description = self.entity_map[spec["topologyDescription"]] + self.assertIsInstance(description, TopologyDescription) + self.assertEqual(description.topology_type_name, spec["topologyType"]) + + def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: + """Run the waitForPrimaryChange test operation.""" + client = self.entity_map[spec["client"]] + old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]] + timeout = spec["timeoutMS"] / 1000.0 + + def get_primary(td: TopologyDescription) -> Optional[_Address]: + servers = writable_server_selector(Selection.from_topology_description(td)) + if servers and servers[0].server_type == SERVER_TYPE.RSPrimary: + return servers[0].address + return None + + old_primary = get_primary(old_description) + + def primary_changed() -> bool: + primary = client.primary + if primary is None: + return False + return primary != old_primary + + wait_until(primary_changed, "change primary", timeout=timeout) + + def _testOperation_runOnThread(self, spec): + """Run the 'runOnThread' operation.""" + thread = self.entity_map[spec["thread"]] + thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + + def _testOperation_waitForThread(self, spec): + """Run the 'waitForThread' operation.""" + thread = self.entity_map[spec["thread"]] + thread.stop() + thread.join(10) + if thread.exc: + raise thread.exc + self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) + + async def _testOperation_loop(self, spec): + failure_key = spec.get("storeFailuresAsEntity") + error_key = spec.get("storeErrorsAsEntity") + successes_key = spec.get("storeSuccessesAsEntity") + iteration_key = spec.get("storeIterationsAsEntity") + iteration_limiter_key = spec.get("numIterations") + for i in [failure_key, error_key]: + if i: + self.entity_map[i] = [] + for i in [successes_key, iteration_key]: + if i: + self.entity_map[i] = 0 + i = 0 + global IS_INTERRUPTED + while True: + if iteration_limiter_key and i >= iteration_limiter_key: + break + i += 1 + if IS_INTERRUPTED: + break + try: + if iteration_key: + self.entity_map._entities[iteration_key] += 1 + for op in spec["operations"]: + await self.run_entity_operation(op) + if successes_key: + self.entity_map._entities[successes_key] += 1 + except Exception as exc: + if isinstance(exc, AssertionError): + key = failure_key or error_key + else: + key = error_key or failure_key + if not key: + raise + self.entity_map[key].append( + {"error": str(exc), "time": time.time(), "type": type(exc).__name__} + ) + + async def run_special_operation(self, spec): + opname = spec["name"] + method_name = f"_testOperation_{opname}" + try: + method = getattr(self, method_name) + except AttributeError: + self.fail(f"Unsupported special test operation {opname}") + else: + if iscoroutinefunction(method): + await method(spec["arguments"]) + else: + method(spec["arguments"]) + + 
async def run_operations(self, spec): + for op in spec: + if op["object"] == "testRunner": + await self.run_special_operation(op) + else: + await self.run_entity_operation(op) + + async def run_operations_and_throw(self, spec): + for op in spec: + if op["object"] == "testRunner": + await self.run_special_operation(op) + else: + result = await self.run_entity_operation(op) + if isinstance(result, Exception): + raise result + + def check_events(self, spec): + for event_spec in spec: + client_name = event_spec["client"] + events = event_spec["events"] + event_type = event_spec.get("eventType", "command") + ignore_extra_events = event_spec.get("ignoreExtraEvents", False) + server_connection_id = event_spec.get("serverConnectionId") + has_server_connection_id = event_spec.get("hasServerConnectionId", False) + listener = self.entity_map.get_listener_for_client(client_name) + actual_events = listener.get_events(event_type) + if ignore_extra_events: + actual_events = actual_events[: len(events)] + + if len(events) == 0: + self.assertEqual(actual_events, []) + continue + + if len(actual_events) != len(events): + expected = "\n".join(str(e) for e in events) + actual = "\n".join(str(a) for a in actual_events) + self.assertEqual( + len(actual_events), + len(events), + f"expected events:\n{expected}\nactual events:\n{actual}", + ) + + for idx, expected_event in enumerate(events): + self.match_evaluator.match_event(expected_event, actual_events[idx]) + + if has_server_connection_id: + assert server_connection_id is not None + assert server_connection_id >= 0 + else: + assert server_connection_id is None + + def process_ignore_messages(self, ignore_logs, actual_logs): + final_logs = [] + for log in actual_logs: + ignored = False + for ignore_log in ignore_logs: + if log["data"]["message"] == ignore_log["data"][ + "message" + ] and self.match_evaluator.match_result(ignore_log, log, test=False): + ignored = True + break + if not ignored: + final_logs.append(log) + return final_logs + + async def check_log_messages(self, operations, spec): + def format_logs(log_list): + client_to_log = defaultdict(list) + for log in log_list: + if log.module == "ocsp_support": + continue + data = json_util.loads(log.getMessage()) + client = data.pop("clientId") if "clientId" in data else data.pop("topologyId") + client_to_log[client].append( + { + "level": log.levelname.lower(), + "component": log.name.replace("pymongo.", "", 1), + "data": data, + } + ) + return client_to_log + + with self.assertLogs("pymongo", level="DEBUG") as cm: + await self.run_operations(operations) + formatted_logs = format_logs(cm.records) + for client in spec: + components = set() + for message in client["messages"]: + components.add(message["component"]) + + clientid = self.entity_map[client["client"]]._topology_settings._topology_id + actual_logs = formatted_logs[clientid] + actual_logs = [log for log in actual_logs if log["component"] in components] + + ignore_logs = client.get("ignoreMessages", []) + if ignore_logs: + actual_logs = self.process_ignore_messages(ignore_logs, actual_logs) + + if client.get("ignoreExtraMessages", False): + actual_logs = actual_logs[: len(client["messages"])] + self.assertEqual( + len(client["messages"]), + len(actual_logs), + f"expected {client['messages']} but got {actual_logs}", + ) + for expected_msg, actual_msg in zip(client["messages"], actual_logs): + expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") + + if "failureIsRedacted" in expected_msg: + self.assertIn("failure", 
actual_data) + should_redact = expected_msg.pop("failureIsRedacted") + if should_redact: + actual_fields = set(json_util.loads(actual_data["failure"]).keys()) + self.assertTrue( + {"code", "codeName", "errorLabels"}.issuperset(actual_fields) + ) + + self.match_evaluator.match_result(expected_data, actual_data) + self.match_evaluator.match_result(expected_msg, actual_msg) + + async def verify_outcome(self, spec): + for collection_data in spec: + coll_name = collection_data["collectionName"] + db_name = collection_data["databaseName"] + expected_documents = collection_data["documents"] + + coll = self.client.get_database(db_name).get_collection( + coll_name, + read_preference=ReadPreference.PRIMARY, + read_concern=ReadConcern(level="local"), + ) + + if expected_documents: + sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"]) + actual_documents = await coll.find({}, sort=[("_id", ASCENDING)]).to_list() + self.assertListEqual(sorted_expected_documents, actual_documents) + + async def run_scenario(self, spec, uri=None): + if "csot" in self.id().lower() and SKIP_CSOT_TESTS: + raise unittest.SkipTest("SKIP_CSOT_TESTS is set, skipping...") + + # Kill all sessions before and after each test to prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. + await self.kill_all_sessions() + self.addAsyncCleanup(self.kill_all_sessions) + + if "csot" in self.id().lower(): + # Retry CSOT tests up to 2 times to deal with flakey tests. + attempts = 3 + for i in range(attempts): + try: + return await self._run_scenario(spec, uri) + except AssertionError: + if i < attempts - 1: + print( + f"Retrying after attempt {i+1} of {self.id()} failed with:\n" + f"{traceback.format_exc()}", + file=sys.stderr, + ) + await self.asyncSetUp() + continue + raise + return None + else: + await self._run_scenario(spec, uri) + return None + + async def _run_scenario(self, spec, uri=None): + # maybe skip test manually + self.maybe_skip_test(spec) + + # process test-level runOnRequirements + run_on_spec = spec.get("runOnRequirements", []) + if not await self.should_run_on(run_on_spec): + raise unittest.SkipTest("runOnRequirements not satisfied") + + # process skipReason + skip_reason = spec.get("skipReason", None) + if skip_reason is not None: + raise unittest.SkipTest(f"{skip_reason}") + + # process createEntities + self._uri = uri + self.entity_map = EntityMapUtil(self) + await self.entity_map.create_entities_from_spec( + self.TEST_SPEC.get("createEntities", []), uri=uri + ) + # process initialData + if "initialData" in self.TEST_SPEC: + await self.insert_initial_data(self.TEST_SPEC["initialData"]) + self._cluster_time = (await self.client.admin.command("ping")).get("$clusterTime") + await self.entity_map.advance_cluster_times() + + if "expectLogMessages" in spec: + expect_log_messages = spec["expectLogMessages"] + self.assertTrue(expect_log_messages, "expectEvents must be non-empty") + await self.check_log_messages(spec["operations"], expect_log_messages) + else: + # process operations + await self.run_operations(spec["operations"]) + + # process expectEvents + if "expectEvents" in spec: + expect_events = spec["expectEvents"] + self.assertTrue(expect_events, "expectEvents must be non-empty") + self.check_events(expect_events) + + # process outcome + await self.verify_outcome(spec.get("outcome", [])) + + +class UnifiedSpecTestMeta(type): + """Metaclass for generating test classes.""" + + TEST_SPEC: Any + EXPECTED_FAILURES: 
Any + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + + def create_test(spec): + async def test_case(self): + await self.run_scenario(spec) + + return test_case + + for test_spec in cls.TEST_SPEC["tests"]: + description = test_spec["description"] + test_name = "test_{}".format( + description.strip(". ").replace(" ", "_").replace(".", "_") + ) + test_method = create_test(copy.deepcopy(test_spec)) + test_method.__name__ = str(test_name) + + for fail_pattern in cls.EXPECTED_FAILURES: + if re.search(fail_pattern, description): + test_method = unittest.expectedFailure(test_method) + break + + setattr(cls, test_name, test_method) + + +_ALL_MIXIN_CLASSES = [ + UnifiedSpecTestMixinV1, + # add mixin classes for new schema major versions here +] + + +_SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = { + KLASS.SCHEMA_VERSION[0]: KLASS for KLASS in _ALL_MIXIN_CLASSES +} + + +def generate_test_classes( + test_path, + module=__name__, + class_name_prefix="", + expected_failures=[], # noqa: B006 + bypass_test_generation_errors=False, + **kwargs, +): + """Method for generating test classes. Returns a dictionary where keys are + the names of test classes and values are the test class objects. + """ + test_klasses = {} + + def test_base_class_factory(test_spec): + """Utility that creates the base class to use for test generation. + This is needed to ensure that cls.TEST_SPEC is appropriately set when + the metaclass __init__ is invoked. + """ + + class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore + TEST_SPEC = test_spec + EXPECTED_FAILURES = expected_failures + + return SpecTestBase + + for dirpath, _, filenames in os.walk(test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + fpath = os.path.join(dirpath, filename) + with open(fpath) as scenario_stream: + # Use tz_aware=False to match how CodecOptions decodes + # dates. 
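+                    # (naive datetime objects rather than timezone-aware ones)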
+ opts = json_util.JSONOptions(tz_aware=False) + scenario_def = json_util.loads(scenario_stream.read(), json_options=opts) + + test_type = os.path.splitext(filename)[0] + snake_class_name = "Test{}_{}_{}".format( + class_name_prefix, + dirname.replace("-", "_"), + test_type.replace("-", "_").replace(".", "_"), + ) + class_name = snake_to_camel(snake_class_name) + + try: + schema_version = Version.from_string(scenario_def["schemaVersion"]) + mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(schema_version[0]) + if mixin_class is None: + raise ValueError( + f"test file '{fpath}' has unsupported schemaVersion '{schema_version}'" + ) + module_dict = {"__module__": module, "TEST_PATH": test_path} + module_dict.update(kwargs) + test_klasses[class_name] = type( + class_name, + ( + mixin_class, + test_base_class_factory(scenario_def), + ), + module_dict, + ) + except Exception: + if bypass_test_generation_errors: + continue + raise + + return test_klasses diff --git a/test/unified_format.py b/test/unified_format.py index 62211d3d25..6a19082b86 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -18,41 +18,41 @@ """ from __future__ import annotations +import asyncio import binascii -import collections import copy -import datetime import functools import os import re import sys import time import traceback -import types -from collections import abc, defaultdict +from asyncio import iscoroutinefunction +from collections import defaultdict from test import ( IntegrationTest, client_context, client_knobs, unittest, ) -from test.helpers import ( - AWS_CREDS, - AWS_CREDS_2, - AZURE_CREDS, - CA_PEM, - CLIENT_PEM, - GCP_CREDS, - KMIP_CREDS, - LOCAL_MASTER_KEY, - client_knobs, +from test.unified_format_shared import ( + IS_INTERRUPTED, + KMS_TLS_OPTS, + PLACEHOLDER_MAP, + SKIP_CSOT_TESTS, + EventListenerUtil, + MatchEvaluatorUtil, + coerce_result, + parse_bulk_write_error_result, + parse_bulk_write_result, + parse_client_bulk_write_error_result, + parse_collection_or_database_options, + with_metaclass, ) from test.utils import ( - CMAPListener, camel_to_snake, camel_to_snake_args, get_pool, - parse_collection_options, parse_spec_options, prepare_spec_arguments, snake_to_camel, @@ -60,14 +60,12 @@ ) from test.utils_spec_runner import SpecRunnerThread from test.version import Version -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional import pymongo -from bson import SON, Code, DBRef, Decimal128, Int64, MaxKey, MinKey, json_util -from bson.binary import Binary +from bson import SON, json_util from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.objectid import ObjectId -from bson.regex import RE_TYPE, Regex from gridfs import GridFSBucket, GridOut from pymongo import ASCENDING, CursorType, MongoClient, _csot from pymongo.encryption_options import _HAVE_PYMONGOCRYPT @@ -83,55 +81,14 @@ PyMongoError, ) from pymongo.monitoring import ( - _SENSITIVE_COMMANDS, - CommandFailedEvent, - CommandListener, CommandStartedEvent, - CommandSucceededEvent, - ConnectionCheckedInEvent, - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutStartedEvent, - ConnectionClosedEvent, - ConnectionCreatedEvent, - ConnectionReadyEvent, - PoolClearedEvent, - PoolClosedEvent, - PoolCreatedEvent, - PoolReadyEvent, - ServerClosedEvent, - ServerDescriptionChangedEvent, - ServerHeartbeatFailedEvent, - ServerHeartbeatListener, - ServerHeartbeatStartedEvent, - ServerHeartbeatSucceededEvent, - ServerListener, - 
ServerOpeningEvent, - TopologyClosedEvent, - TopologyDescriptionChangedEvent, - TopologyEvent, - TopologyListener, - TopologyOpenedEvent, - _CommandEvent, - _ConnectionEvent, - _PoolEvent, - _ServerEvent, - _ServerHeartbeatEvent, ) from pymongo.operations import ( - DeleteMany, - DeleteOne, - InsertOne, - ReplaceOne, SearchIndexModel, - UpdateMany, - UpdateOne, ) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, ClientBulkWriteResult from pymongo.server_api import ServerApi -from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection, writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.synchronous.change_stream import ChangeStream @@ -140,85 +97,12 @@ from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.database import Database from pymongo.synchronous.encryption import ClientEncryption +from pymongo.synchronous.helpers import next from pymongo.topology_description import TopologyDescription from pymongo.typings import _Address from pymongo.write_concern import WriteConcern -SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS") - -JSON_OPTS = json_util.JSONOptions(tz_aware=False) - -IS_INTERRUPTED = False - -KMS_TLS_OPTS = { - "kmip": { - "tlsCAFile": CA_PEM, - "tlsCertificateKeyFile": CLIENT_PEM, - } -} - - -# Build up a placeholder maps. -PLACEHOLDER_MAP = {} -for provider_name, provider_data in [ - ("local", {"key": LOCAL_MASTER_KEY}), - ("local:name1", {"key": LOCAL_MASTER_KEY}), - ("aws", AWS_CREDS), - ("aws:name1", AWS_CREDS), - ("aws:name2", AWS_CREDS_2), - ("azure", AZURE_CREDS), - ("azure:name1", AZURE_CREDS), - ("gcp", GCP_CREDS), - ("gcp:name1", GCP_CREDS), - ("kmip", KMIP_CREDS), - ("kmip:name1", KMIP_CREDS), -]: - for key, value in provider_data.items(): - placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" - PLACEHOLDER_MAP[placeholder] = value - -OIDC_ENV = os.environ.get("OIDC_ENV", "test") -if OIDC_ENV == "test": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"} -elif OIDC_ENV == "azure": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { - "ENVIRONMENT": "azure", - "TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"], - } -elif OIDC_ENV == "gcp": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { - "ENVIRONMENT": "gcp", - "TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"], - } - - -def interrupt_loop(): - global IS_INTERRUPTED - IS_INTERRUPTED = True - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass. - - Vendored from six: https://github.com/benjaminp/six/blob/master/six.py - """ - - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(type): - def __new__(cls, name, this_bases, d): - # __orig_bases__ is required by PEP 560. 
- resolved_bases = types.resolve_bases(bases) - if resolved_bases is not bases: - d["__orig_bases__"] = bases - return meta(name, resolved_bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - - return type.__new__(metaclass, "temporary_class", (), {}) +_IS_SYNC = True def is_run_on_requirement_satisfied(requirement): @@ -283,77 +167,6 @@ def is_run_on_requirement_satisfied(requirement): ) -def parse_collection_or_database_options(options): - return parse_collection_options(options) - - -def parse_bulk_write_result(result): - upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids} - return { - "deletedCount": result.deleted_count, - "insertedCount": result.inserted_count, - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": result.upserted_count, - "upsertedIds": upserted_ids, - } - - -def parse_client_bulk_write_individual(op_type, result): - if op_type == "insert": - return {"insertedId": result.inserted_id} - if op_type == "update": - if result.upserted_id: - return { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedId": result.upserted_id, - } - else: - return { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - } - if op_type == "delete": - return { - "deletedCount": result.deleted_count, - } - - -def parse_client_bulk_write_result(result): - insert_results, update_results, delete_results = {}, {}, {} - if result.has_verbose_results: - for idx, res in result.insert_results.items(): - insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res) - for idx, res in result.update_results.items(): - update_results[str(idx)] = parse_client_bulk_write_individual("update", res) - for idx, res in result.delete_results.items(): - delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res) - - return { - "deletedCount": result.deleted_count, - "insertedCount": result.inserted_count, - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": result.upserted_count, - "insertResults": insert_results, - "updateResults": update_results, - "deleteResults": delete_results, - } - - -def parse_bulk_write_error_result(error): - write_result = BulkWriteResult(error.details, True) - return parse_bulk_write_result(write_result) - - -def parse_client_bulk_write_error_result(error): - write_result = error.partial_result - if not write_result: - return None - return parse_client_bulk_write_result(write_result) - - class NonLazyCursor: """A find cursor proxy that creates the remote cursor when initialized.""" @@ -361,7 +174,16 @@ def __init__(self, find_cursor, client): self.client = client self.find_cursor = find_cursor # Create the server side cursor. 
- self.first_result = next(find_cursor, None) + self.first_result = None + + @classmethod + def create(cls, find_cursor, client): + cursor = cls(find_cursor, client) + try: + cursor.first_result = next(cursor.find_cursor) + except StopIteration: + cursor.first_result = None + return cursor @property def alive(self): @@ -382,105 +204,6 @@ def close(self): self.client = None -class EventListenerUtil( - CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener -): - def __init__( - self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map - ): - self._event_types = {name.lower() for name in observe_events} - if observe_sensitive_commands: - self._observe_sensitive_commands = True - self._ignore_commands = set(ignore_commands) - else: - self._observe_sensitive_commands = False - self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) - self._ignore_commands.add("configurefailpoint") - self._event_mapping = collections.defaultdict(list) - self.entity_map = entity_map - if store_events: - for i in store_events: - id = i["id"] - events = (i.lower() for i in i["events"]) - for i in events: - self._event_mapping[i].append(id) - self.entity_map[id] = [] - super().__init__() - - def get_events(self, event_type): - assert event_type in ("command", "cmap", "sdam", "all"), event_type - if event_type == "all": - return list(self.events) - if event_type == "command": - return [e for e in self.events if isinstance(e, _CommandEvent)] - if event_type == "cmap": - return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))] - return [ - e - for e in self.events - if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent)) - ] - - def add_event(self, event): - event_name = type(event).__name__.lower() - if event_name in self._event_types: - super().add_event(event) - for id in self._event_mapping[event_name]: - self.entity_map[id].append( - { - "name": type(event).__name__, - "observedAt": time.time(), - "description": repr(event), - } - ) - - def _command_event(self, event): - if event.command_name.lower() not in self._ignore_commands: - self.add_event(event) - - def started(self, event): - if isinstance(event, CommandStartedEvent): - if event.command == {}: - # Command is redacted. Observe only if flag is set. - if self._observe_sensitive_commands: - self._command_event(event) - else: - self._command_event(event) - else: - self.add_event(event) - - def succeeded(self, event): - if isinstance(event, CommandSucceededEvent): - if event.reply == {}: - # Command is redacted. Observe only if flag is set. - if self._observe_sensitive_commands: - self._command_event(event) - else: - self._command_event(event) - else: - self.add_event(event) - - def failed(self, event): - if isinstance(event, CommandFailedEvent): - self._command_event(event) - else: - self.add_event(event) - - def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None: - self.add_event(event) - - def description_changed( - self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent] - ) -> None: - self.add_event(event) - - def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None: - self.add_event(event) - - def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None: - self.add_event(event) - - class EntityMapUtil: """Utility class that implements an entity map as per the unified test format specification. 
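The NonLazyCursor hunk above introduces the two-phase construction this PR uses wherever __init__ used to perform I/O: an async __init__ cannot await, so the eager fetch of the first result moves into a classmethod factory, and synchro strips the awaits when it generates this sync copy. A minimal sketch of what the async counterpart looks like (the class name is illustrative, and the anext() builtin assumes Python 3.10+):

class AsyncNonLazyCursor:
    """A find cursor proxy that creates the remote cursor when initialized."""

    def __init__(self, find_cursor, client):
        self.client = client
        self.find_cursor = find_cursor
        self.first_result = None

    @classmethod
    async def create(cls, find_cursor, client):
        # The fetch that used to live in __init__ happens here, where awaiting is legal.
        cursor = cls(find_cursor, client)
        try:
            cursor.first_result = await anext(cursor.find_cursor)
        except StopAsyncIteration:
            cursor.first_result = None
        return cursor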
@@ -692,353 +415,12 @@ def get_lsid_for_session(self, session_name): def advance_cluster_times(self) -> None: """Manually synchronize entities when desired""" if not self._cluster_time: - self._cluster_time = self.test.client.admin.command("ping").get("$clusterTime") + self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime") for entity in self._entities.values(): if isinstance(entity, ClientSession) and self._cluster_time: entity.advance_cluster_time(self._cluster_time) -binary_types = (Binary, bytes) -long_types = (Int64,) -unicode_type = str - - -BSON_TYPE_ALIAS_MAP = { - # https://mongodb.com/docs/manual/reference/operator/query/type/ - # https://pymongo.readthedocs.io/en/stable/api/bson/index.html - "double": (float,), - "string": (str,), - "object": (abc.Mapping,), - "array": (abc.MutableSequence,), - "binData": binary_types, - "undefined": (type(None),), - "objectId": (ObjectId,), - "bool": (bool,), - "date": (datetime.datetime,), - "null": (type(None),), - "regex": (Regex, RE_TYPE), - "dbPointer": (DBRef,), - "javascript": (unicode_type, Code), - "symbol": (unicode_type,), - "javascriptWithScope": (unicode_type, Code), - "int": (int,), - "long": (Int64,), - "decimal": (Decimal128,), - "maxKey": (MaxKey,), - "minKey": (MinKey,), -} - - -class MatchEvaluatorUtil: - """Utility class that implements methods for evaluating matches as per - the unified test format specification. - """ - - def __init__(self, test_class): - self.test = test_class - - def _operation_exists(self, spec, actual, key_to_compare): - if spec is True: - if key_to_compare is None: - assert actual is not None - else: - self.test.assertIn(key_to_compare, actual) - elif spec is False: - if key_to_compare is None: - assert actual is None - else: - self.test.assertNotIn(key_to_compare, actual) - else: - self.test.fail(f"Expected boolean value for $$exists operator, got {spec}") - - def __type_alias_to_type(self, alias): - if alias not in BSON_TYPE_ALIAS_MAP: - self.test.fail(f"Unrecognized BSON type alias {alias}") - return BSON_TYPE_ALIAS_MAP[alias] - - def _operation_type(self, spec, actual, key_to_compare): - if isinstance(spec, abc.MutableSequence): - permissible_types = tuple( - [t for alias in spec for t in self.__type_alias_to_type(alias)] - ) - else: - permissible_types = self.__type_alias_to_type(spec) - value = actual[key_to_compare] if key_to_compare else actual - self.test.assertIsInstance(value, permissible_types) - - def _operation_matchesEntity(self, spec, actual, key_to_compare): - expected_entity = self.test.entity_map[spec] - self.test.assertEqual(expected_entity, actual[key_to_compare]) - - def _operation_matchesHexBytes(self, spec, actual, key_to_compare): - expected = binascii.unhexlify(spec) - value = actual[key_to_compare] if key_to_compare else actual - self.test.assertEqual(value, expected) - - def _operation_unsetOrMatches(self, spec, actual, key_to_compare): - if key_to_compare is None and not actual: - # top-level document can be None when unset - return - - if key_to_compare not in actual: - # we add a dummy value for the compared key to pass map size check - actual[key_to_compare] = "dummyValue" - return - self.match_result(spec, actual[key_to_compare], in_recursive_call=True) - - def _operation_sessionLsid(self, spec, actual, key_to_compare): - expected_lsid = self.test.entity_map.get_lsid_for_session(spec) - self.test.assertEqual(expected_lsid, actual[key_to_compare]) - - def _operation_lte(self, spec, actual, key_to_compare): - if key_to_compare not in actual: - 
self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}") - self.test.assertLessEqual(actual[key_to_compare], spec) - - def _operation_matchAsDocument(self, spec, actual, key_to_compare): - self._match_document(spec, json_util.loads(actual[key_to_compare]), False) - - def _operation_matchAsRoot(self, spec, actual, key_to_compare): - self._match_document(spec, actual, True) - - def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): - method_name = "_operation_{}".format(opname.strip("$")) - try: - method = getattr(self, method_name) - except AttributeError: - self.test.fail(f"Unsupported special matching operator {opname}") - else: - method(spec, actual, key_to_compare) - - def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None): - """Returns True if a special operation is evaluated, False - otherwise. If the ``expectation`` map contains a single key, - value pair we check it for a special operation. - If given, ``key_to_compare`` is assumed to be the key in - ``expectation`` whose corresponding value needs to be - evaluated for a possible special operation. ``key_to_compare`` - is ignored when ``expectation`` has only one key. - """ - if not isinstance(expectation, abc.Mapping): - return False - - is_special_op, opname, spec = False, False, False - - if key_to_compare is not None: - if key_to_compare.startswith("$$"): - is_special_op = True - opname = key_to_compare - spec = expectation[key_to_compare] - key_to_compare = None - else: - nested = expectation[key_to_compare] - if isinstance(nested, abc.Mapping) and len(nested) == 1: - opname, spec = next(iter(nested.items())) - if opname.startswith("$$"): - is_special_op = True - elif len(expectation) == 1: - opname, spec = next(iter(expectation.items())) - if opname.startswith("$$"): - is_special_op = True - key_to_compare = None - - if is_special_op: - self._evaluate_special_operation( - opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare - ) - return True - - return False - - def _match_document(self, expectation, actual, is_root, test=False): - if self._evaluate_if_special_operation(expectation, actual): - return - - self.test.assertIsInstance(actual, abc.Mapping) - for key, value in expectation.items(): - if self._evaluate_if_special_operation(expectation, actual, key): - continue - - self.test.assertIn(key, actual) - if not self.match_result(value, actual[key], in_recursive_call=True, test=test): - return False - - if not is_root: - expected_keys = set(expectation.keys()) - for key, value in expectation.items(): - if value == {"$$exists": False}: - expected_keys.remove(key) - if test: - self.test.assertEqual(expected_keys, set(actual.keys())) - else: - return set(expected_keys).issubset(set(actual.keys())) - return True - - def match_result(self, expectation, actual, in_recursive_call=False, test=True): - if isinstance(expectation, abc.Mapping): - return self._match_document( - expectation, actual, is_root=not in_recursive_call, test=test - ) - - if isinstance(expectation, abc.MutableSequence): - self.test.assertIsInstance(actual, abc.MutableSequence) - for e, a in zip(expectation, actual): - if isinstance(e, abc.Mapping): - self._match_document(e, a, is_root=not in_recursive_call, test=test) - else: - self.match_result(e, a, in_recursive_call=True, test=test) - return None - - # account for flexible numerics in element-wise comparison - if isinstance(expectation, int) or isinstance(expectation, float): - if test: - 
self.test.assertEqual(expectation, actual) - else: - return expectation == actual - return None - else: - if test: - self.test.assertIsInstance(actual, type(expectation)) - self.test.assertEqual(expectation, actual) - else: - return isinstance(actual, type(expectation)) and expectation == actual - return None - - def match_server_description(self, actual: ServerDescription, spec: dict) -> None: - for field, expected in spec.items(): - field = camel_to_snake(field) - if field == "type": - field = "server_type_name" - self.test.assertEqual(getattr(actual, field), expected) - - def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None: - for field, expected in spec.items(): - field = camel_to_snake(field) - if field == "type": - field = "topology_type_name" - self.test.assertEqual(getattr(actual, field), expected) - - def match_event_fields(self, actual: Any, spec: dict) -> None: - for field, expected in spec.items(): - if field == "command" and isinstance(actual, CommandStartedEvent): - command = spec["command"] - if command: - self.match_result(command, actual.command) - continue - if field == "reply" and isinstance(actual, CommandSucceededEvent): - reply = spec["reply"] - if reply: - self.match_result(reply, actual.reply) - continue - if field == "hasServiceId": - if spec["hasServiceId"]: - self.test.assertIsNotNone(actual.service_id) - self.test.assertIsInstance(actual.service_id, ObjectId) - else: - self.test.assertIsNone(actual.service_id) - continue - if field == "hasServerConnectionId": - if spec["hasServerConnectionId"]: - self.test.assertIsNotNone(actual.server_connection_id) - self.test.assertIsInstance(actual.server_connection_id, int) - else: - self.test.assertIsNone(actual.server_connection_id) - continue - if field in ("previousDescription", "newDescription"): - if isinstance(actual, ServerDescriptionChangedEvent): - self.match_server_description( - getattr(actual, camel_to_snake(field)), spec[field] - ) - continue - if isinstance(actual, TopologyDescriptionChangedEvent): - self.match_topology_description( - getattr(actual, camel_to_snake(field)), spec[field] - ) - continue - - if field == "interruptInUseConnections": - field = "interrupt_connections" - else: - field = camel_to_snake(field) - self.test.assertEqual(getattr(actual, field), expected) - - def match_event(self, expectation, actual): - name, spec = next(iter(expectation.items())) - if name == "commandStartedEvent": - self.test.assertIsInstance(actual, CommandStartedEvent) - elif name == "commandSucceededEvent": - self.test.assertIsInstance(actual, CommandSucceededEvent) - elif name == "commandFailedEvent": - self.test.assertIsInstance(actual, CommandFailedEvent) - elif name == "poolCreatedEvent": - self.test.assertIsInstance(actual, PoolCreatedEvent) - elif name == "poolReadyEvent": - self.test.assertIsInstance(actual, PoolReadyEvent) - elif name == "poolClearedEvent": - self.test.assertIsInstance(actual, PoolClearedEvent) - self.test.assertIsInstance(actual.interrupt_connections, bool) - elif name == "poolClosedEvent": - self.test.assertIsInstance(actual, PoolClosedEvent) - elif name == "connectionCreatedEvent": - self.test.assertIsInstance(actual, ConnectionCreatedEvent) - elif name == "connectionReadyEvent": - self.test.assertIsInstance(actual, ConnectionReadyEvent) - elif name == "connectionClosedEvent": - self.test.assertIsInstance(actual, ConnectionClosedEvent) - elif name == "connectionCheckOutStartedEvent": - self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) - elif 
name == "connectionCheckOutFailedEvent": - self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) - elif name == "connectionCheckedOutEvent": - self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) - elif name == "connectionCheckedInEvent": - self.test.assertIsInstance(actual, ConnectionCheckedInEvent) - elif name == "serverDescriptionChangedEvent": - self.test.assertIsInstance(actual, ServerDescriptionChangedEvent) - elif name == "serverHeartbeatStartedEvent": - self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent) - elif name == "serverHeartbeatSucceededEvent": - self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent) - elif name == "serverHeartbeatFailedEvent": - self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent) - elif name == "topologyDescriptionChangedEvent": - self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent) - elif name == "topologyOpeningEvent": - self.test.assertIsInstance(actual, TopologyOpenedEvent) - elif name == "topologyClosedEvent": - self.test.assertIsInstance(actual, TopologyClosedEvent) - else: - raise Exception(f"Unsupported event type {name}") - - self.match_event_fields(actual, spec) - - -def coerce_result(opname, result): - """Convert a pymongo result into the spec's result format.""" - if hasattr(result, "acknowledged") and not result.acknowledged: - return {"acknowledged": False} - if opname == "bulkWrite": - return parse_bulk_write_result(result) - if opname == "clientBulkWrite": - return parse_client_bulk_write_result(result) - if opname == "insertOne": - return {"insertedId": result.inserted_id} - if opname == "insertMany": - return dict(enumerate(result.inserted_ids)) - if opname in ("deleteOne", "deleteMany"): - return {"deletedCount": result.deleted_count} - if opname in ("updateOne", "updateMany", "replaceOne"): - value = { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": 0 if result.upserted_id is None else 1, - } - if result.upserted_id is not None: - value["upsertedId"] = result.upserted_id - return value - return result - - class UnifiedSpecTestMixinV1(IntegrationTest): """Mixin class to run test cases from test specification files. 
@@ -1090,9 +472,9 @@ def insert_initial_data(self, initial_data): db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - def setUpClass(cls): + def _setup_class(cls): # super call creates internal client cls.client - super().setUpClass() + super()._setup_class() # process file-level runOnRequirements run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) if not cls.should_run_on(run_on_spec): @@ -1125,11 +507,11 @@ def setUpClass(cls): cls.knobs.enable() @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.knobs.disable() for client in cls.mongos_clients: client.close() - super().tearDownClass() + super()._tearDown_class() def setUp(self): super().setUp() @@ -1391,7 +773,7 @@ def _databaseOperation_createCollection(self, target, *args, **kwargs): def __entityOperation_aggregate(self, target, *args, **kwargs): self.__raise_if_unsupported("aggregate", target, Database, Collection) - return list(target.aggregate(*args, **kwargs)) + return (target.aggregate(*args, **kwargs)).to_list() def _databaseOperation_aggregate(self, target, *args, **kwargs): return self.__entityOperation_aggregate(target, *args, **kwargs) @@ -1402,13 +784,13 @@ def _collectionOperation_aggregate(self, target, *args, **kwargs): def _collectionOperation_find(self, target, *args, **kwargs): self.__raise_if_unsupported("find", target, Collection) find_cursor = target.find(*args, **kwargs) - return list(find_cursor) + return find_cursor.to_list() def _collectionOperation_createFindCursor(self, target, *args, **kwargs): self.__raise_if_unsupported("find", target, Collection) if "filter" not in kwargs: self.fail('createFindCursor requires a "filter" argument') - cursor = NonLazyCursor(target.find(*args, **kwargs), target.database.client) + cursor = NonLazyCursor.create(target.find(*args, **kwargs), target.database.client) self.addCleanup(cursor.close) return cursor @@ -1418,7 +800,7 @@ def _collectionOperation_count(self, target, *args, **kwargs): def _collectionOperation_listIndexes(self, target, *args, **kwargs): if "batch_size" in kwargs: self.skipTest("PyMongo does not support batch_size for list_indexes") - return list(target.list_indexes(*args, **kwargs)) + return (target.list_indexes(*args, **kwargs)).to_list() def _collectionOperation_listIndexNames(self, target, *args, **kwargs): self.skipTest("PyMongo does not support list_index_names") @@ -1430,7 +812,7 @@ def _collectionOperation_createSearchIndexes(self, target, *args, **kwargs): def _collectionOperation_listSearchIndexes(self, target, *args, **kwargs): name = kwargs.get("name") agg_kwargs = kwargs.get("aggregation_options", dict()) - return list(target.list_search_indexes(name, **agg_kwargs)) + return (target.list_search_indexes(name, **agg_kwargs)).to_list() def _sessionOperation_withTransaction(self, target, *args, **kwargs): if client_context.storage_engine == "mmapv1": @@ -1470,7 +852,7 @@ def _clientEncryptionOperation_createDataKey(self, target, *args, **kwargs): return target.create_data_key(*args, **kwargs) def _clientEncryptionOperation_getKeys(self, target, *args, **kwargs): - return list(target.get_keys(*args, **kwargs)) + return (target.get_keys(*args, **kwargs)).to_list() def _clientEncryptionOperation_deleteKey(self, target, *args, **kwargs): result = target.delete_key(*args, **kwargs) @@ -1516,7 +898,7 @@ def _bucketOperation_uploadWithId(self, target: GridFSBucket, *args: Any, **kwar def _bucketOperation_find( self, target: GridFSBucket, *args: Any, **kwargs: Any ) -> List[GridOut]: - return 
list(target.find(*args, **kwargs)) + return target.find(*args, **kwargs).to_list() def run_entity_operation(self, spec): target = self.entity_map[spec["object"]] @@ -1849,7 +1231,10 @@ def run_special_operation(self, spec): except AttributeError: self.fail(f"Unsupported special test operation {opname}") else: - method(spec["arguments"]) + if iscoroutinefunction(method): + method(spec["arguments"]) + else: + method(spec["arguments"]) def run_operations(self, spec): for op in spec: @@ -1985,7 +1370,7 @@ def verify_outcome(self, spec): if expected_documents: sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"]) - actual_documents = list(coll.find({}, sort=[("_id", ASCENDING)])) + actual_documents = coll.find({}, sort=[("_id", ASCENDING)]).to_list() self.assertListEqual(sorted_expected_documents, actual_documents) def run_scenario(self, spec, uri=None): @@ -2040,7 +1425,7 @@ def _run_scenario(self, spec, uri=None): # process initialData if "initialData" in self.TEST_SPEC: self.insert_initial_data(self.TEST_SPEC["initialData"]) - self._cluster_time = self.client.admin.command("ping").get("$clusterTime") + self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime") self.entity_map.advance_cluster_times() if "expectLogMessages" in spec: diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py new file mode 100644 index 0000000000..d11624476d --- /dev/null +++ b/test/unified_format_shared.py @@ -0,0 +1,679 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utility functions and constants for the unified test format runner. 
+ +https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst +""" +from __future__ import annotations + +import binascii +import collections +import datetime +import os +import time +import types +from collections import abc +from test.helpers import ( + AWS_CREDS, + AWS_CREDS_2, + AZURE_CREDS, + CA_PEM, + CLIENT_PEM, + GCP_CREDS, + KMIP_CREDS, + LOCAL_MASTER_KEY, +) +from test.utils import CMAPListener, camel_to_snake, parse_collection_options +from typing import Any, Union + +from bson import ( + RE_TYPE, + Binary, + Code, + DBRef, + Decimal128, + Int64, + MaxKey, + MinKey, + ObjectId, + Regex, + json_util, +) +from pymongo.monitoring import ( + _SENSITIVE_COMMANDS, + CommandFailedEvent, + CommandListener, + CommandStartedEvent, + CommandSucceededEvent, + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, + ServerClosedEvent, + ServerDescriptionChangedEvent, + ServerHeartbeatFailedEvent, + ServerHeartbeatListener, + ServerHeartbeatStartedEvent, + ServerHeartbeatSucceededEvent, + ServerListener, + ServerOpeningEvent, + TopologyClosedEvent, + TopologyDescriptionChangedEvent, + TopologyEvent, + TopologyListener, + TopologyOpenedEvent, + _CommandEvent, + _ConnectionEvent, + _PoolEvent, + _ServerEvent, + _ServerHeartbeatEvent, +) +from pymongo.results import BulkWriteResult +from pymongo.server_description import ServerDescription +from pymongo.topology_description import TopologyDescription + +SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS") + +JSON_OPTS = json_util.JSONOptions(tz_aware=False) + +IS_INTERRUPTED = False + +KMS_TLS_OPTS = { + "kmip": { + "tlsCAFile": CA_PEM, + "tlsCertificateKeyFile": CLIENT_PEM, + } +} + + +# Build up a placeholder maps. +PLACEHOLDER_MAP = {} +for provider_name, provider_data in [ + ("local", {"key": LOCAL_MASTER_KEY}), + ("local:name1", {"key": LOCAL_MASTER_KEY}), + ("aws", AWS_CREDS), + ("aws:name1", AWS_CREDS), + ("aws:name2", AWS_CREDS_2), + ("azure", AZURE_CREDS), + ("azure:name1", AZURE_CREDS), + ("gcp", GCP_CREDS), + ("gcp:name1", GCP_CREDS), + ("kmip", KMIP_CREDS), + ("kmip:name1", KMIP_CREDS), +]: + for key, value in provider_data.items(): + placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" + PLACEHOLDER_MAP[placeholder] = value + +OIDC_ENV = os.environ.get("OIDC_ENV", "test") +if OIDC_ENV == "test": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"} +elif OIDC_ENV == "azure": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"], + } +elif OIDC_ENV == "gcp": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { + "ENVIRONMENT": "gcp", + "TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"], + } + + +def interrupt_loop(): + global IS_INTERRUPTED + IS_INTERRUPTED = True + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass. + + Vendored from six: https://github.com/benjaminp/six/blob/master/six.py + """ + + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(type): + def __new__(cls, name, this_bases, d): + # __orig_bases__ is required by PEP 560. 
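+            # A step worth spelling out: types.resolve_bases() (PEP 560)
+            # replaces any non-class entries in ``bases``, such as typing
+            # generics that define __mro_entries__, with real classes. When
+            # that substitution changes the tuple, the original bases are
+            # stashed in __orig_bases__ so introspection can still find them.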
+ resolved_bases = types.resolve_bases(bases) + if resolved_bases is not bases: + d["__orig_bases__"] = bases + return meta(name, resolved_bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + + return type.__new__(metaclass, "temporary_class", (), {}) + + +def parse_collection_or_database_options(options): + return parse_collection_options(options) + + +def parse_bulk_write_result(result): + upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids} + return { + "deletedCount": result.deleted_count, + "insertedCount": result.inserted_count, + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": result.upserted_count, + "upsertedIds": upserted_ids, + } + + +def parse_client_bulk_write_individual(op_type, result): + if op_type == "insert": + return {"insertedId": result.inserted_id} + if op_type == "update": + if result.upserted_id: + return { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedId": result.upserted_id, + } + else: + return { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + } + if op_type == "delete": + return { + "deletedCount": result.deleted_count, + } + + +def parse_client_bulk_write_result(result): + insert_results, update_results, delete_results = {}, {}, {} + if result.has_verbose_results: + for idx, res in result.insert_results.items(): + insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res) + for idx, res in result.update_results.items(): + update_results[str(idx)] = parse_client_bulk_write_individual("update", res) + for idx, res in result.delete_results.items(): + delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res) + + return { + "deletedCount": result.deleted_count, + "insertedCount": result.inserted_count, + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": result.upserted_count, + "insertResults": insert_results, + "updateResults": update_results, + "deleteResults": delete_results, + } + + +def parse_bulk_write_error_result(error): + write_result = BulkWriteResult(error.details, True) + return parse_bulk_write_result(write_result) + + +def parse_client_bulk_write_error_result(error): + write_result = error.partial_result + if not write_result: + return None + return parse_client_bulk_write_result(write_result) + + +class EventListenerUtil( + CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener +): + def __init__( + self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map + ): + self._event_types = {name.lower() for name in observe_events} + if observe_sensitive_commands: + self._observe_sensitive_commands = True + self._ignore_commands = set(ignore_commands) + else: + self._observe_sensitive_commands = False + self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) + self._ignore_commands.add("configurefailpoint") + self._event_mapping = collections.defaultdict(list) + self.entity_map = entity_map + if store_events: + for i in store_events: + id = i["id"] + events = (i.lower() for i in i["events"]) + for i in events: + self._event_mapping[i].append(id) + self.entity_map[id] = [] + super().__init__() + + def get_events(self, event_type): + assert event_type in ("command", "cmap", "sdam", "all"), event_type + if event_type == "all": + return list(self.events) + if event_type == 
"command": + return [e for e in self.events if isinstance(e, _CommandEvent)] + if event_type == "cmap": + return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))] + return [ + e + for e in self.events + if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent)) + ] + + def add_event(self, event): + event_name = type(event).__name__.lower() + if event_name in self._event_types: + super().add_event(event) + for id in self._event_mapping[event_name]: + self.entity_map[id].append( + { + "name": type(event).__name__, + "observedAt": time.time(), + "description": repr(event), + } + ) + + def _command_event(self, event): + if event.command_name.lower() not in self._ignore_commands: + self.add_event(event) + + def started(self, event): + if isinstance(event, CommandStartedEvent): + if event.command == {}: + # Command is redacted. Observe only if flag is set. + if self._observe_sensitive_commands: + self._command_event(event) + else: + self._command_event(event) + else: + self.add_event(event) + + def succeeded(self, event): + if isinstance(event, CommandSucceededEvent): + if event.reply == {}: + # Command is redacted. Observe only if flag is set. + if self._observe_sensitive_commands: + self._command_event(event) + else: + self._command_event(event) + else: + self.add_event(event) + + def failed(self, event): + if isinstance(event, CommandFailedEvent): + self._command_event(event) + else: + self.add_event(event) + + def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None: + self.add_event(event) + + def description_changed( + self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent] + ) -> None: + self.add_event(event) + + def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None: + self.add_event(event) + + def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None: + self.add_event(event) + + +binary_types = (Binary, bytes) +long_types = (Int64,) +unicode_type = str + + +BSON_TYPE_ALIAS_MAP = { + # https://mongodb.com/docs/manual/reference/operator/query/type/ + # https://pymongo.readthedocs.io/en/stable/api/bson/index.html + "double": (float,), + "string": (str,), + "object": (abc.Mapping,), + "array": (abc.MutableSequence,), + "binData": binary_types, + "undefined": (type(None),), + "objectId": (ObjectId,), + "bool": (bool,), + "date": (datetime.datetime,), + "null": (type(None),), + "regex": (Regex, RE_TYPE), + "dbPointer": (DBRef,), + "javascript": (unicode_type, Code), + "symbol": (unicode_type,), + "javascriptWithScope": (unicode_type, Code), + "int": (int,), + "long": (Int64,), + "decimal": (Decimal128,), + "maxKey": (MaxKey,), + "minKey": (MinKey,), +} + + +class MatchEvaluatorUtil: + """Utility class that implements methods for evaluating matches as per + the unified test format specification. 
+ """ + + def __init__(self, test_class): + self.test = test_class + + def _operation_exists(self, spec, actual, key_to_compare): + if spec is True: + if key_to_compare is None: + assert actual is not None + else: + self.test.assertIn(key_to_compare, actual) + elif spec is False: + if key_to_compare is None: + assert actual is None + else: + self.test.assertNotIn(key_to_compare, actual) + else: + self.test.fail(f"Expected boolean value for $$exists operator, got {spec}") + + def __type_alias_to_type(self, alias): + if alias not in BSON_TYPE_ALIAS_MAP: + self.test.fail(f"Unrecognized BSON type alias {alias}") + return BSON_TYPE_ALIAS_MAP[alias] + + def _operation_type(self, spec, actual, key_to_compare): + if isinstance(spec, abc.MutableSequence): + permissible_types = tuple( + [t for alias in spec for t in self.__type_alias_to_type(alias)] + ) + else: + permissible_types = self.__type_alias_to_type(spec) + value = actual[key_to_compare] if key_to_compare else actual + self.test.assertIsInstance(value, permissible_types) + + def _operation_matchesEntity(self, spec, actual, key_to_compare): + expected_entity = self.test.entity_map[spec] + self.test.assertEqual(expected_entity, actual[key_to_compare]) + + def _operation_matchesHexBytes(self, spec, actual, key_to_compare): + expected = binascii.unhexlify(spec) + value = actual[key_to_compare] if key_to_compare else actual + self.test.assertEqual(value, expected) + + def _operation_unsetOrMatches(self, spec, actual, key_to_compare): + if key_to_compare is None and not actual: + # top-level document can be None when unset + return + + if key_to_compare not in actual: + # we add a dummy value for the compared key to pass map size check + actual[key_to_compare] = "dummyValue" + return + self.match_result(spec, actual[key_to_compare], in_recursive_call=True) + + def _operation_sessionLsid(self, spec, actual, key_to_compare): + expected_lsid = self.test.entity_map.get_lsid_for_session(spec) + self.test.assertEqual(expected_lsid, actual[key_to_compare]) + + def _operation_lte(self, spec, actual, key_to_compare): + if key_to_compare not in actual: + self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}") + self.test.assertLessEqual(actual[key_to_compare], spec) + + def _operation_matchAsDocument(self, spec, actual, key_to_compare): + self._match_document(spec, json_util.loads(actual[key_to_compare]), False) + + def _operation_matchAsRoot(self, spec, actual, key_to_compare): + self._match_document(spec, actual, True) + + def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): + method_name = "_operation_{}".format(opname.strip("$")) + try: + method = getattr(self, method_name) + except AttributeError: + self.test.fail(f"Unsupported special matching operator {opname}") + else: + method(spec, actual, key_to_compare) + + def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None): + """Returns True if a special operation is evaluated, False + otherwise. If the ``expectation`` map contains a single key, + value pair we check it for a special operation. + If given, ``key_to_compare`` is assumed to be the key in + ``expectation`` whose corresponding value needs to be + evaluated for a possible special operation. ``key_to_compare`` + is ignored when ``expectation`` has only one key. 
+ """ + if not isinstance(expectation, abc.Mapping): + return False + + is_special_op, opname, spec = False, False, False + + if key_to_compare is not None: + if key_to_compare.startswith("$$"): + is_special_op = True + opname = key_to_compare + spec = expectation[key_to_compare] + key_to_compare = None + else: + nested = expectation[key_to_compare] + if isinstance(nested, abc.Mapping) and len(nested) == 1: + opname, spec = next(iter(nested.items())) + if opname.startswith("$$"): + is_special_op = True + elif len(expectation) == 1: + opname, spec = next(iter(expectation.items())) + if opname.startswith("$$"): + is_special_op = True + key_to_compare = None + + if is_special_op: + self._evaluate_special_operation( + opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare + ) + return True + + return False + + def _match_document(self, expectation, actual, is_root, test=False): + if self._evaluate_if_special_operation(expectation, actual): + return + + self.test.assertIsInstance(actual, abc.Mapping) + for key, value in expectation.items(): + if self._evaluate_if_special_operation(expectation, actual, key): + continue + + self.test.assertIn(key, actual) + if not self.match_result(value, actual[key], in_recursive_call=True, test=test): + return False + + if not is_root: + expected_keys = set(expectation.keys()) + for key, value in expectation.items(): + if value == {"$$exists": False}: + expected_keys.remove(key) + if test: + self.test.assertEqual(expected_keys, set(actual.keys())) + else: + return set(expected_keys).issubset(set(actual.keys())) + return True + + def match_result(self, expectation, actual, in_recursive_call=False, test=True): + if isinstance(expectation, abc.Mapping): + return self._match_document( + expectation, actual, is_root=not in_recursive_call, test=test + ) + + if isinstance(expectation, abc.MutableSequence): + self.test.assertIsInstance(actual, abc.MutableSequence) + for e, a in zip(expectation, actual): + if isinstance(e, abc.Mapping): + self._match_document(e, a, is_root=not in_recursive_call, test=test) + else: + self.match_result(e, a, in_recursive_call=True, test=test) + return None + + # account for flexible numerics in element-wise comparison + if isinstance(expectation, int) or isinstance(expectation, float): + if test: + self.test.assertEqual(expectation, actual) + else: + return expectation == actual + return None + else: + if test: + self.test.assertIsInstance(actual, type(expectation)) + self.test.assertEqual(expectation, actual) + else: + return isinstance(actual, type(expectation)) and expectation == actual + return None + + def match_server_description(self, actual: ServerDescription, spec: dict) -> None: + for field, expected in spec.items(): + field = camel_to_snake(field) + if field == "type": + field = "server_type_name" + self.test.assertEqual(getattr(actual, field), expected) + + def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None: + for field, expected in spec.items(): + field = camel_to_snake(field) + if field == "type": + field = "topology_type_name" + self.test.assertEqual(getattr(actual, field), expected) + + def match_event_fields(self, actual: Any, spec: dict) -> None: + for field, expected in spec.items(): + if field == "command" and isinstance(actual, CommandStartedEvent): + command = spec["command"] + if command: + self.match_result(command, actual.command) + continue + if field == "reply" and isinstance(actual, CommandSucceededEvent): + reply = spec["reply"] + if reply: + 
self.match_result(reply, actual.reply) + continue + if field == "hasServiceId": + if spec["hasServiceId"]: + self.test.assertIsNotNone(actual.service_id) + self.test.assertIsInstance(actual.service_id, ObjectId) + else: + self.test.assertIsNone(actual.service_id) + continue + if field == "hasServerConnectionId": + if spec["hasServerConnectionId"]: + self.test.assertIsNotNone(actual.server_connection_id) + self.test.assertIsInstance(actual.server_connection_id, int) + else: + self.test.assertIsNone(actual.server_connection_id) + continue + if field in ("previousDescription", "newDescription"): + if isinstance(actual, ServerDescriptionChangedEvent): + self.match_server_description( + getattr(actual, camel_to_snake(field)), spec[field] + ) + continue + if isinstance(actual, TopologyDescriptionChangedEvent): + self.match_topology_description( + getattr(actual, camel_to_snake(field)), spec[field] + ) + continue + + if field == "interruptInUseConnections": + field = "interrupt_connections" + else: + field = camel_to_snake(field) + self.test.assertEqual(getattr(actual, field), expected) + + def match_event(self, expectation, actual): + name, spec = next(iter(expectation.items())) + if name == "commandStartedEvent": + self.test.assertIsInstance(actual, CommandStartedEvent) + elif name == "commandSucceededEvent": + self.test.assertIsInstance(actual, CommandSucceededEvent) + elif name == "commandFailedEvent": + self.test.assertIsInstance(actual, CommandFailedEvent) + elif name == "poolCreatedEvent": + self.test.assertIsInstance(actual, PoolCreatedEvent) + elif name == "poolReadyEvent": + self.test.assertIsInstance(actual, PoolReadyEvent) + elif name == "poolClearedEvent": + self.test.assertIsInstance(actual, PoolClearedEvent) + self.test.assertIsInstance(actual.interrupt_connections, bool) + elif name == "poolClosedEvent": + self.test.assertIsInstance(actual, PoolClosedEvent) + elif name == "connectionCreatedEvent": + self.test.assertIsInstance(actual, ConnectionCreatedEvent) + elif name == "connectionReadyEvent": + self.test.assertIsInstance(actual, ConnectionReadyEvent) + elif name == "connectionClosedEvent": + self.test.assertIsInstance(actual, ConnectionClosedEvent) + elif name == "connectionCheckOutStartedEvent": + self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) + elif name == "connectionCheckOutFailedEvent": + self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) + elif name == "connectionCheckedOutEvent": + self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) + elif name == "connectionCheckedInEvent": + self.test.assertIsInstance(actual, ConnectionCheckedInEvent) + elif name == "serverDescriptionChangedEvent": + self.test.assertIsInstance(actual, ServerDescriptionChangedEvent) + elif name == "serverHeartbeatStartedEvent": + self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent) + elif name == "serverHeartbeatSucceededEvent": + self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent) + elif name == "serverHeartbeatFailedEvent": + self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent) + elif name == "topologyDescriptionChangedEvent": + self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent) + elif name == "topologyOpeningEvent": + self.test.assertIsInstance(actual, TopologyOpenedEvent) + elif name == "topologyClosedEvent": + self.test.assertIsInstance(actual, TopologyClosedEvent) + else: + raise Exception(f"Unsupported event type {name}") + + self.match_event_fields(actual, spec) + + +def coerce_result(opname, 
result): + """Convert a pymongo result into the spec's result format.""" + if hasattr(result, "acknowledged") and not result.acknowledged: + return {"acknowledged": False} + if opname == "bulkWrite": + return parse_bulk_write_result(result) + if opname == "clientBulkWrite": + return parse_client_bulk_write_result(result) + if opname == "insertOne": + return {"insertedId": result.inserted_id} + if opname == "insertMany": + return dict(enumerate(result.inserted_ids)) + if opname in ("deleteOne", "deleteMany"): + return {"deletedCount": result.deleted_count} + if opname in ("updateOne", "updateMany", "replaceOne"): + value = { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": 0 if result.upserted_id is None else 1, + } + if result.upserted_id is not None: + value["upsertedId"] = result.upserted_id + return value + return result diff --git a/tools/synchro.py b/tools/synchro.py index f704919a17..e0af5efa44 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -205,6 +205,7 @@ def async_only_test(f: str) -> bool: "test_retryable_writes.py", "test_session.py", "test_transactions.py", + "unified_format.py", ] sync_test_files = [ From 7e86d24c7bffe4da0a4d32580b5da0e6230b78d2 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 13:59:37 -0400 Subject: [PATCH 04/35] PYTHON-4849 - Convert test.test_connection_logging.py to async (#1918) --- test/asynchronous/test_connection_logging.py | 45 ++++++++++++++++++++ test/test_connection_logging.py | 8 +++- tools/synchro.py | 1 + 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_connection_logging.py diff --git a/test/asynchronous/test_connection_logging.py b/test/asynchronous/test_connection_logging.py new file mode 100644 index 0000000000..6bc9835b70 --- /dev/null +++ b/test/asynchronous/test_connection_logging.py @@ -0,0 +1,45 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the connection logging unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. 
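+# The async tests live one directory deeper, in test/asynchronous/, so the
+# shared JSON specs are reached through parent.parent; the sync copy that
+# tools/synchro.py generates from this file flips _IS_SYNC to True and
+# resolves the same specs through parent.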
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_connection_logging.py b/test/test_connection_logging.py index 262ce821eb..253193cc43 100644 --- a/test/test_connection_logging.py +++ b/test/test_connection_logging.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,13 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_logging") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") globals().update( diff --git a/tools/synchro.py b/tools/synchro.py index e0af5efa44..dbaf0a15e9 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -193,6 +193,7 @@ def async_only_test(f: str) -> bool: "test_collation.py", "test_collection.py", "test_common.py", + "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", "test_cursor.py", "test_database.py", From e0fde2338126ee3e8ca7771b3f88c4a2706638f2 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 13:59:44 -0400 Subject: [PATCH 05/35] PYTHON-4850 - Convert test.test_crud_unified to async (#1920) --- test/asynchronous/test_crud_unified.py | 39 ++++++++++++++++++++++++++ test/test_crud_unified.py | 10 +++++-- tools/synchro.py | 1 + 3 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 test/asynchronous/test_crud_unified.py diff --git a/test/asynchronous/test_crud_unified.py b/test/asynchronous/test_crud_unified.py new file mode 100644 index 0000000000..3d8deb36e9 --- /dev/null +++ b/test/asynchronous/test_crud_unified.py @@ -0,0 +1,39 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the CRUD unified spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") + +# Generate unified tests. 
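+# generate_test_classes() scans _TEST_PATH, builds one TestCase subclass per
+# JSON spec file, and returns them in a dict keyed by class name; pushing that
+# dict into globals() is what makes the tests visible to unittest discovery.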
+globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_crud_unified.py b/test/test_crud_unified.py
index 92a60a47fc..26f34cba88 100644
--- a/test/test_crud_unified.py
+++ b/test/test_crud_unified.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import os
+import pathlib
 import sys
 
 sys.path[0:0] = [""]
@@ -23,11 +24,16 @@
 from test import unittest
 from test.unified_format import generate_test_classes
 
+_IS_SYNC = True
+
 # Location of JSON test specifications.
-TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "unified")
+if _IS_SYNC:
+    _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified")
+else:
+    _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified")
 
 # Generate unified tests.
-globals().update(generate_test_classes(TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
+globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tools/synchro.py b/tools/synchro.py
index dbaf0a15e9..39ce7fbdd0 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -195,6 +195,7 @@ def async_only_test(f: str) -> bool:
     "test_common.py",
     "test_connection_logging.py",
     "test_connections_survive_primary_stepdown_spec.py",
+    "test_crud_unified.py",
     "test_cursor.py",
     "test_database.py",
     "test_encryption.py",

From b2332b2aaeb26ecd7efa4992f037ca4dc56583db Mon Sep 17 00:00:00 2001
From: Noah Stapp
Date: Fri, 11 Oct 2024 13:59:49 -0400
Subject: [PATCH 06/35] PYTHON-4846 - Convert test.test_command_logging.py to
 async (#1915)

---
 test/asynchronous/test_command_logging.py | 44 +++++++++++++++++++++++
 test/test_command_logging.py              |  9 ++++-
 tools/synchro.py                          |  1 +
 3 files changed, 53 insertions(+), 1 deletion(-)
 create mode 100644 test/asynchronous/test_command_logging.py

diff --git a/test/asynchronous/test_command_logging.py b/test/asynchronous/test_command_logging.py
new file mode 100644
index 0000000000..f9b459c152
--- /dev/null
+++ b/test/asynchronous/test_command_logging.py
@@ -0,0 +1,44 @@
+# Copyright 2023-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Run the command logging unified format spec tests."""
+from __future__ import annotations
+
+import os
+import pathlib
+import sys
+
+sys.path[0:0] = [""]
+
+from test import unittest
+from test.asynchronous.unified_format import generate_test_classes
+
+_IS_SYNC = False
+
+# Location of JSON test specifications.
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_command_logging.py b/test/test_command_logging.py index 9b2d52e66b..cf865920ca 100644 --- a/test/test_command_logging.py +++ b/test/test_command_logging.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,14 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_logging") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") + globals().update( generate_test_classes( diff --git a/tools/synchro.py b/tools/synchro.py index 39ce7fbdd0..f40a64e4c2 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -192,6 +192,7 @@ def async_only_test(f: str) -> bool: "test_client_context.py", "test_collation.py", "test_collection.py", + "test_command_logging.py", "test_common.py", "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", From 4eeaa4b7be9e814fd207166904f42556c10ce63b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 14:56:43 -0400 Subject: [PATCH 07/35] PYTHON-4848 - Convert test.test_command_monitoring.py to async (#1917) --- test/asynchronous/test_command_monitoring.py | 45 ++++++++++++++++++++ test/test_command_monitoring.py | 8 +++- tools/synchro.py | 1 + 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_command_monitoring.py diff --git a/test/asynchronous/test_command_monitoring.py b/test/asynchronous/test_command_monitoring.py new file mode 100644 index 0000000000..311fd1fdc1 --- /dev/null +++ b/test/asynchronous/test_command_monitoring.py @@ -0,0 +1,45 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the command monitoring unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. 
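+# As in PATCH 06 (and unlike PATCH 04/05, which still import the sync
+# test.unified_format), this file takes generate_test_classes from
+# test.asynchronous.unified_format, so the generated TestCase classes drive
+# the async client; the sync twin generated below keeps the sync import.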
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_command_monitoring.py b/test/test_command_monitoring.py index d2f578824d..4f5ef06f28 100644 --- a/test/test_command_monitoring.py +++ b/test/test_command_monitoring.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,13 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_monitoring") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") globals().update( diff --git a/tools/synchro.py b/tools/synchro.py index f40a64e4c2..b6812e9be6 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -193,6 +193,7 @@ def async_only_test(f: str) -> bool: "test_collation.py", "test_collection.py", "test_command_logging.py", + "test_command_monitoring.py", "test_common.py", "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", From 33163ecc0d4fe7dc8f7bfc12ef93d89513203fe2 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:02:13 -0700 Subject: [PATCH 08/35] PYTHON-4804 Migrate test_comment.py to async (#1887) --- test/asynchronous/test_comment.py | 159 ++++++++++++++++++++++++++++++ test/test_comment.py | 60 ++++------- tools/synchro.py | 2 + 3 files changed, 179 insertions(+), 42 deletions(-) create mode 100644 test/asynchronous/test_comment.py diff --git a/test/asynchronous/test_comment.py b/test/asynchronous/test_comment.py new file mode 100644 index 0000000000..be3626a8b8 --- /dev/null +++ b/test/asynchronous/test_comment.py @@ -0,0 +1,159 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the keyword argument 'comment' in various helpers.""" + +from __future__ import annotations + +import inspect +import sys + +sys.path[0:0] = [""] +from asyncio import iscoroutinefunction +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import OvertCommandListener + +from bson.dbref import DBRef +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.operations import IndexModel + +_IS_SYNC = False + + +class AsyncTestComment(AsyncIntegrationTest): + async def _test_ops( + self, + helpers, + already_supported, + listener, + ): + for h, args in helpers: + c = "testing comment with " + h.__name__ + with self.subTest("collection-" + h.__name__ + "-comment"): + for cc in [c, {"key": c}, ["any", 1]]: + listener.reset() + kwargs = {"comment": cc} + try: + maybe_cursor = await h(*args, **kwargs) + except Exception: + maybe_cursor = None + self.assertIn( + "comment", + inspect.signature(h).parameters, + msg="Could not find 'comment' in the " + "signature of function %s" % (h.__name__), + ) + self.assertEqual( + inspect.signature(h).parameters["comment"].annotation, "Optional[Any]" + ) + if isinstance(maybe_cursor, AsyncCommandCursor): + await maybe_cursor.close() + + cmd = listener.started_events[0] + self.assertEqual(cc, cmd.command.get("comment"), msg=cmd) + + if h.__name__ != "aggregate_raw_batches": + self.assertIn( + ":param comment:", + h.__doc__, + ) + if h not in already_supported: + self.assertIn( + "Added ``comment`` parameter", + h.__doc__, + ) + else: + self.assertNotIn( + "Added ``comment`` parameter", + h.__doc__, + ) + + listener.reset() + + @async_client_context.require_version_min(4, 7, -1) + @async_client_context.require_replica_set + async def test_database_helpers(self): + listener = OvertCommandListener() + db = (await self.async_rs_or_single_client(event_listeners=[listener])).db + helpers = [ + (db.watch, []), + (db.command, ["hello"]), + (db.list_collections, []), + (db.list_collection_names, []), + (db.drop_collection, ["hello"]), + (db.validate_collection, ["test"]), + (db.dereference, [DBRef("collection", 1)]), + ] + already_supported = [db.command, db.list_collections, db.list_collection_names] + await self._test_ops(helpers, already_supported, listener) + + @async_client_context.require_version_min(4, 7, -1) + @async_client_context.require_replica_set + async def test_client_helpers(self): + listener = OvertCommandListener() + cli = await self.async_rs_or_single_client(event_listeners=[listener]) + helpers = [ + (cli.watch, []), + (cli.list_databases, []), + (cli.list_database_names, []), + (cli.drop_database, ["test"]), + ] + already_supported = [ + cli.list_databases, + ] + await self._test_ops(helpers, already_supported, listener) + + @async_client_context.require_version_min(4, 7, -1) + async def test_collection_helpers(self): + listener = OvertCommandListener() + db = (await self.async_rs_or_single_client(event_listeners=[listener]))[self.db.name] + coll = db.get_collection("test") + + helpers = [ + (coll.list_indexes, []), + (coll.drop, []), + (coll.index_information, []), + (coll.options, []), + (coll.aggregate, [[{"$set": {"x": 1}}]]), + (coll.aggregate_raw_batches, [[{"$set": {"x": 1}}]]), + (coll.rename, ["temp_temp_temp"]), + (coll.distinct, ["_id"]), + (coll.find_one_and_delete, [{}]), + (coll.find_one_and_replace, [{}, {}]), + (coll.find_one_and_update, [{}, {"$set": {"a": 1}}]), + (coll.estimated_document_count, []), + (coll.count_documents, [{}]), + 
(coll.create_indexes, [[IndexModel("a")]]), + (coll.create_index, ["a"]), + (coll.drop_index, [[("a", 1)]]), + (coll.drop_indexes, []), + ] + already_supported = [ + coll.estimated_document_count, + coll.count_documents, + coll.create_indexes, + coll.drop_indexes, + coll.options, + coll.find_one_and_replace, + coll.drop_index, + coll.rename, + coll.distinct, + coll.find_one_and_delete, + coll.find_one_and_update, + ] + await self._test_ops(helpers, already_supported, listener) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_comment.py b/test/test_comment.py index c0f037ea44..9f9bf98640 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -20,24 +20,15 @@ import sys sys.path[0:0] = [""] - +from asyncio import iscoroutinefunction from test import IntegrationTest, client_context, unittest -from test.utils import EventListener +from test.utils import OvertCommandListener from bson.dbref import DBRef from pymongo.operations import IndexModel from pymongo.synchronous.command_cursor import CommandCursor - -class Empty: - def __getattr__(self, item): - try: - self.__dict__[item] - except KeyError: - return self.empty - - def empty(self, *args, **kwargs): - return Empty() +_IS_SYNC = True class TestComment(IntegrationTest): @@ -46,8 +37,6 @@ def _test_ops( helpers, already_supported, listener, - db=Empty(), # noqa: B008 - coll=Empty(), # noqa: B008 ): for h, args in helpers: c = "testing comment with " + h.__name__ @@ -55,19 +44,10 @@ def _test_ops( for cc in [c, {"key": c}, ["any", 1]]: listener.reset() kwargs = {"comment": cc} - if h == coll.rename: - _ = db.get_collection("temp_temp_temp").drop() - destruct_coll = db.get_collection("test_temp") - destruct_coll.insert_one({}) - maybe_cursor = destruct_coll.rename(*args, **kwargs) - destruct_coll.drop() - elif h == db.validate_collection: - coll = db.get_collection("test") - coll.insert_one({}) - maybe_cursor = db.validate_collection(*args, **kwargs) - else: - coll.create_index("a") + try: maybe_cursor = h(*args, **kwargs) + except Exception: + maybe_cursor = None self.assertIn( "comment", inspect.signature(h).parameters, @@ -79,15 +59,11 @@ def _test_ops( ) if isinstance(maybe_cursor, CommandCursor): maybe_cursor.close() - tested = False - # For some reason collection.list_indexes creates two commands and the first - # one doesn't contain 'comment'. 
- for i in listener.started_events: - if cc == i.command.get("comment", ""): - self.assertEqual(cc, i.command["comment"]) - tested = True - self.assertTrue(tested) - if h not in [coll.aggregate_raw_batches]: + + cmd = listener.started_events[0] + self.assertEqual(cc, cmd.command.get("comment"), msg=cmd) + + if h.__name__ != "aggregate_raw_batches": self.assertIn( ":param comment:", h.__doc__, @@ -108,8 +84,8 @@ def _test_ops( @client_context.require_version_min(4, 7, -1) @client_context.require_replica_set def test_database_helpers(self): - listener = EventListener() - db = self.rs_or_single_client(event_listeners=[listener]).db + listener = OvertCommandListener() + db = (self.rs_or_single_client(event_listeners=[listener])).db helpers = [ (db.watch, []), (db.command, ["hello"]), @@ -120,12 +96,12 @@ def test_database_helpers(self): (db.dereference, [DBRef("collection", 1)]), ] already_supported = [db.command, db.list_collections, db.list_collection_names] - self._test_ops(helpers, already_supported, listener, db=db, coll=db.get_collection("test")) + self._test_ops(helpers, already_supported, listener) @client_context.require_version_min(4, 7, -1) @client_context.require_replica_set def test_client_helpers(self): - listener = EventListener() + listener = OvertCommandListener() cli = self.rs_or_single_client(event_listeners=[listener]) helpers = [ (cli.watch, []), @@ -140,8 +116,8 @@ def test_client_helpers(self): @client_context.require_version_min(4, 7, -1) def test_collection_helpers(self): - listener = EventListener() - db = self.rs_or_single_client(event_listeners=[listener])[self.db.name] + listener = OvertCommandListener() + db = (self.rs_or_single_client(event_listeners=[listener]))[self.db.name] coll = db.get_collection("test") helpers = [ @@ -176,7 +152,7 @@ def test_collection_helpers(self): coll.find_one_and_delete, coll.find_one_and_update, ] - self._test_ops(helpers, already_supported, listener, coll=coll, db=db) + self._test_ops(helpers, already_supported, listener) if __name__ == "__main__": diff --git a/tools/synchro.py b/tools/synchro.py index b6812e9be6..25f506ed5a 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -193,7 +193,9 @@ def async_only_test(f: str) -> bool: "test_collation.py", "test_collection.py", "test_command_logging.py", + "test_command_logging.py", "test_command_monitoring.py", + "test_comment.py", "test_common.py", "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", From 3c5e71a1cb28b695bc2eec4c3927ef6af56835a8 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 14 Oct 2024 07:32:38 -0500 Subject: [PATCH 09/35] PYTHON-4862 Fix handling of interrupt_loop in unified test runner (#1924) --- test/asynchronous/unified_format.py | 8 +++++++- test/unified_format.py | 8 +++++++- test/unified_format_shared.py | 5 ----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 4c37422951..42bda59cb2 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -36,7 +36,6 @@ unittest, ) from test.unified_format_shared import ( - IS_INTERRUPTED, KMS_TLS_OPTS, PLACEHOLDER_MAP, SKIP_CSOT_TESTS, @@ -104,6 +103,13 @@ _IS_SYNC = False +IS_INTERRUPTED = False + + +def interrupt_loop(): + global IS_INTERRUPTED + IS_INTERRUPTED = True + async def is_run_on_requirement_satisfied(requirement): topology_satisfied = True diff --git a/test/unified_format.py b/test/unified_format.py index 6a19082b86..13ab0af69b 100644 --- 
a/test/unified_format.py +++ b/test/unified_format.py @@ -36,7 +36,6 @@ unittest, ) from test.unified_format_shared import ( - IS_INTERRUPTED, KMS_TLS_OPTS, PLACEHOLDER_MAP, SKIP_CSOT_TESTS, @@ -104,6 +103,13 @@ _IS_SYNC = True +IS_INTERRUPTED = False + + +def interrupt_loop(): + global IS_INTERRUPTED + IS_INTERRUPTED = True + def is_run_on_requirement_satisfied(requirement): topology_satisfied = True diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py index d11624476d..f1b908a7a6 100644 --- a/test/unified_format_shared.py +++ b/test/unified_format_shared.py @@ -139,11 +139,6 @@ } -def interrupt_loop(): - global IS_INTERRUPTED - IS_INTERRUPTED = True - - def with_metaclass(meta, *bases): """Create a base class with a metaclass. From 9ba780cac256720be5c3c5051c7f8a19d27693d5 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 14 Oct 2024 07:34:01 -0500 Subject: [PATCH 10/35] PYTHON-4861 Ensure hatch is isolated in Evergreen (#1923) --- .evergreen/hatch.sh | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/.evergreen/hatch.sh b/.evergreen/hatch.sh index db0da2f4d0..8f862c39d2 100644 --- a/.evergreen/hatch.sh +++ b/.evergreen/hatch.sh @@ -18,17 +18,22 @@ if [ -n "$SKIP_HATCH" ]; then run_hatch() { bash ./.evergreen/run-tests.sh } -elif $PYTHON_BINARY -m hatch --version; then - run_hatch() { - $PYTHON_BINARY -m hatch run "$@" - } -else # No toolchain hatch present, set up virtualenv before installing hatch +else # Set up virtualenv before installing hatch # Use a random venv name because the encryption tasks run this script multiple times in the same run. ENV_NAME=hatchenv-$RANDOM createvirtualenv "$PYTHON_BINARY" $ENV_NAME # shellcheck disable=SC2064 trap "deactivate; rm -rf $ENV_NAME" EXIT HUP python -m pip install -q hatch + + # Ensure hatch does not write to user or global locations. + touch hatch_config.toml + HATCH_CONFIG=$(pwd)/hatch_config.toml + export HATCH_CONFIG + hatch config restore + hatch config set dirs.data ".hatch/data" + hatch config set dirs.cache ".hatch/cache" + run_hatch() { python -m hatch run "$@" } From 3cc722e9105d5818d57739d623d985d69b0eb626 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 14 Oct 2024 14:05:22 -0500 Subject: [PATCH 11/35] PYTHON-4838 Generate OCSP build variants using shrub.py (#1910) --- .evergreen/config.yml | 174 +++++++++++++++++++++----- .evergreen/scripts/generate_config.py | 167 ++++++++++++++++++++++++ 2 files changed, 308 insertions(+), 33 deletions(-) create mode 100644 .evergreen/scripts/generate_config.py diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 1ef8751501..dee4b608ec 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2826,42 +2826,150 @@ buildvariants: - "test-6.0-standalone" - "test-5.0-standalone" -- matrix_name: "ocsp-test" - matrix_spec: - platform: rhel8 - python-version: ["3.9", "3.10", "pypy3.9", "pypy3.10"] - mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"] - auth: "noauth" - ssl: "ssl" - display_name: "OCSP test ${platform} ${python-version} ${mongodb-version}" - batchtime: 20160 # 14 days +# OCSP test matrix. 
+- name: ocsp-test-rhel8-v4.4-py3.9 tasks: - - name: ".ocsp" - -- matrix_name: "ocsp-test-windows" - matrix_spec: - platform: windows - python-version-windows: ["3.9", "3.10"] - mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"] - auth: "noauth" - ssl: "ssl" - display_name: "OCSP test ${platform} ${python-version-windows} ${mongodb-version}" - batchtime: 20160 # 14 days + - name: .ocsp + display_name: OCSP test RHEL8 v4.4 py3.9 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "4.4" + PYTHON_BINARY: /opt/python/3.9/bin/python3 +- name: ocsp-test-rhel8-v5.0-py3.10 tasks: - # Windows MongoDB servers do not staple OCSP responses and only support RSA. - - name: ".ocsp-rsa !.ocsp-staple" - -- matrix_name: "ocsp-test-macos" - matrix_spec: - platform: macos - mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"] - auth: "noauth" - ssl: "ssl" - display_name: "OCSP test ${platform} ${mongodb-version}" - batchtime: 20160 # 14 days + - name: .ocsp + display_name: OCSP test RHEL8 v5.0 py3.10 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "5.0" + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: ocsp-test-rhel8-v6.0-py3.11 tasks: - # macOS MongoDB servers do not staple OCSP responses and only support RSA. - - name: ".ocsp-rsa !.ocsp-staple" + - name: .ocsp + display_name: OCSP test RHEL8 v6.0 py3.11 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "6.0" + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: ocsp-test-rhel8-v7.0-py3.12 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 v7.0 py3.12 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "7.0" + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: ocsp-test-rhel8-v8.0-py3.13 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 v8.0 py3.13 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "8.0" + PYTHON_BINARY: /opt/python/3.13/bin/python3 +- name: ocsp-test-rhel8-rapid-pypy3.9 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 rapid pypy3.9 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: rapid + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: ocsp-test-rhel8-latest-pypy3.10 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 latest pypy3.10 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: latest + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 +- name: ocsp-test-win64-v4.4-py3.9 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test Win64 v4.4 py3.9 + run_on: + - windows-64-vsMulti-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "4.4" + PYTHON_BINARY: C:/python/Python39/python.exe +- name: ocsp-test-win64-v8.0-py3.13 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test Win64 v8.0 py3.13 + run_on: + - windows-64-vsMulti-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "8.0" + PYTHON_BINARY: C:/python/Python313/python.exe +- name: ocsp-test-macos-v4.4-py3.9 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test macOS v4.4 py3.9 + run_on: + - macos-14 + batchtime: 20160 + expansions: + AUTH: noauth + 
SSL: ssl + TOPOLOGY: server + VERSION: "4.4" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 +- name: ocsp-test-macos-v8.0-py3.13 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test macOS v8.0 py3.13 + run_on: + - macos-14 + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "8.0" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 - matrix_name: "oidc-auth-test" matrix_spec: diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py new file mode 100644 index 0000000000..e98e527b72 --- /dev/null +++ b/.evergreen/scripts/generate_config.py @@ -0,0 +1,167 @@ +# /// script +# requires-python = ">=3.9" +# dependencies = [ +# "shrub.py>=3.2.0", +# "pyyaml>=6.0.2" +# ] +# /// + +# Note: Run this file with `hatch run`, `pipx run`, or `uv run`. +from __future__ import annotations + +from dataclasses import dataclass +from itertools import cycle, product, zip_longest +from typing import Any + +from shrub.v3.evg_build_variant import BuildVariant +from shrub.v3.evg_project import EvgProject +from shrub.v3.evg_task import EvgTaskRef +from shrub.v3.shrub_service import ShrubService + +############## +# Globals +############## + +ALL_VERSIONS = ["4.0", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"] +CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] +PYPYS = ["pypy3.9", "pypy3.10"] +ALL_PYTHONS = CPYTHONS + PYPYS +BATCHTIME_WEEK = 10080 +HOSTS = dict() + + +@dataclass +class Host: + name: str + run_on: str + display_name: str + + +HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8") +HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64") +HOSTS["macos"] = Host("macos", "macos-14", "macOS") + + +############## +# Helpers +############## + + +def create_variant( + task_names: list[str], + display_name: str, + *, + python: str | None = None, + version: str | None = None, + host: str | None = None, + **kwargs: Any, +) -> BuildVariant: + """Create a build variant for the given inputs.""" + task_refs = [EvgTaskRef(name=n) for n in task_names] + kwargs.setdefault("expansions", dict()) + expansions = kwargs.pop("expansions", dict()).copy() + host = host or "rhel8" + run_on = [HOSTS[host].run_on] + name = display_name.replace(" ", "-").lower() + if python: + expansions["PYTHON_BINARY"] = get_python_binary(python, host) + if version: + expansions["VERSION"] = version + expansions = expansions or None + return BuildVariant( + name=name, + display_name=display_name, + tasks=task_refs, + expansions=expansions, + run_on=run_on, + **kwargs, + ) + + +def get_python_binary(python: str, host: str) -> str: + """Get the appropriate python binary given a python version and host.""" + if host == "win64": + is_32 = python.startswith("32-bit") + if is_32: + _, python = python.split() + base = "C:/python/32" + else: + base = "C:/python" + python = python.replace(".", "") + return f"{base}/Python{python}/python.exe" + + if host == "rhel8": + return f"/opt/python/{python}/bin/python3" + + if host == "macos": + return f"/Library/Frameworks/Python.Framework/Versions/{python}/bin/python3" + + raise ValueError(f"no match found for python {python} on {host}") + + +def get_display_name(base: str, host: str, version: str, python: str) -> str: + """Get the display name of a variant.""" + if version not in ["rapid", "latest"]: + version = f"v{version}" + if not python.startswith("pypy"): + python = f"py{python}" + return f"{base} {HOSTS[host].display_name} {version} 
{python}" + + +def zip_cycle(*iterables, empty_default=None): + """Get all combinations of the inputs, cycling over the shorter list(s).""" + cycles = [cycle(i) for i in iterables] + for _ in zip_longest(*iterables): + yield tuple(next(i, empty_default) for i in cycles) + + +############## +# Variants +############## + + +def create_ocsp_variants() -> list[BuildVariant]: + variants = [] + batchtime = BATCHTIME_WEEK * 2 + expansions = dict(AUTH="noauth", SSL="ssl", TOPOLOGY="server") + base_display = "OCSP test" + + # OCSP tests on rhel8 with all servers v4.4+ and all python versions. + versions = [v for v in ALL_VERSIONS if v != "4.0"] + for version, python in zip_cycle(versions, ALL_PYTHONS): + host = "rhel8" + variant = create_variant( + [".ocsp"], + get_display_name(base_display, host, version, python), + python=python, + version=version, + host=host, + expansions=expansions, + batchtime=batchtime, + ) + variants.append(variant) + + # OCSP tests on Windows and MacOS. + # MongoDB servers on these hosts do not staple OCSP responses and only support RSA. + for host, version in product(["win64", "macos"], ["4.4", "8.0"]): + python = CPYTHONS[0] if version == "4.4" else CPYTHONS[-1] + variant = create_variant( + [".ocsp-rsa !.ocsp-staple"], + get_display_name(base_display, host, version, python), + python=python, + version=version, + host=host, + expansions=expansions, + batchtime=batchtime, + ) + variants.append(variant) + + return variants + + +################## +# Generate Config +################## + +project = EvgProject(tasks=None, buildvariants=create_ocsp_variants()) +print(ShrubService.generate_yaml(project)) # noqa: T201 From a911245bde1377c485f06dfd5373d159b7e8aff7 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 14 Oct 2024 15:06:42 -0700 Subject: [PATCH 12/35] PYTHON-4866 Fix test_command_cursor_to_list_csot_applied (#1926) --- test/asynchronous/test_cursor.py | 14 ++++++-------- test/test_cursor.py | 14 ++++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index e79ad00641..ee0a757ed3 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1412,12 +1412,11 @@ async def test_to_list_length(self): self.assertEqual(len(docs), 2) async def test_to_list_csot_applied(self): - client = await self.async_single_client(timeoutMS=500) + client = await self.async_single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - await client.admin.command("ping") - coll = client.pymongo.test - await coll.insert_many([{} for _ in range(5)]) + await coll.insert_many([{} for _ in range(5)]) cursor = coll.find({"$where": delay(1)}) with self.assertRaises(PyMongoError) as ctx: await cursor.to_list() @@ -1454,12 +1453,11 @@ async def test_command_cursor_to_list_length(self): @async_client_context.require_failCommand_blockConnection async def test_command_cursor_to_list_csot_applied(self): - client = await self.async_single_client(timeoutMS=500) + client = await self.async_single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - await client.admin.command("ping") - coll = client.pymongo.test - await coll.insert_many([{} for _ in range(5)]) + await coll.insert_many([{} for _ in range(5)]) fail_command = { "configureFailPoint": "failCommand", "mode": {"times": 
5}, diff --git a/test/test_cursor.py b/test/test_cursor.py index 7c073bf351..7a6dfc9429 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1403,12 +1403,11 @@ def test_to_list_length(self): self.assertEqual(len(docs), 2) def test_to_list_csot_applied(self): - client = self.single_client(timeoutMS=500) + client = self.single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - client.admin.command("ping") - coll = client.pymongo.test - coll.insert_many([{} for _ in range(5)]) + coll.insert_many([{} for _ in range(5)]) cursor = coll.find({"$where": delay(1)}) with self.assertRaises(PyMongoError) as ctx: cursor.to_list() @@ -1445,12 +1444,11 @@ def test_command_cursor_to_list_length(self): @client_context.require_failCommand_blockConnection def test_command_cursor_to_list_csot_applied(self): - client = self.single_client(timeoutMS=500) + client = self.single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - client.admin.command("ping") - coll = client.pymongo.test - coll.insert_many([{} for _ in range(5)]) + coll.insert_many([{} for _ in range(5)]) fail_command = { "configureFailPoint": "failCommand", "mode": {"times": 5}, From 9e38c54fa03d0f719a43ff023894c2a1ad9b5480 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 14 Oct 2024 15:25:21 -0700 Subject: [PATCH 13/35] PYTHON-4861 Fix HATCH_CONFIG on cygwin (#1927) --- .evergreen/hatch.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.evergreen/hatch.sh b/.evergreen/hatch.sh index 8f862c39d2..6f3d36b389 100644 --- a/.evergreen/hatch.sh +++ b/.evergreen/hatch.sh @@ -29,6 +29,9 @@ else # Set up virtualenv before installing hatch # Ensure hatch does not write to user or global locations. 
touch hatch_config.toml HATCH_CONFIG=$(pwd)/hatch_config.toml + if [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin + HATCH_CONFIG=$(cygpath -m "$HATCH_CONFIG") + fi export HATCH_CONFIG hatch config restore hatch config set dirs.data ".hatch/data" From 872fda179e247fb8e1bcc3cf2af3d892788a2e2f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 08:54:42 -0400 Subject: [PATCH 14/35] PYTHON-4574 - FaaS detection logic mistakenly identifies EKS as AWS Lambda (#1908) --- test/asynchronous/test_client.py | 16 ++++++++++++++++ test/test_client.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index faa23348c9..c6b6416c16 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2019,6 +2019,22 @@ async def test_handshake_08_invalid_aws_ec2(self): None, ) + async def test_handshake_09_container_with_provider(self): + await self._test_handshake( + { + ENV_VAR_K8S: "1", + "AWS_LAMBDA_RUNTIME_API": "1", + "AWS_REGION": "us-east-1", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "256", + }, + { + "container": {"orchestrator": "kubernetes"}, + "name": "aws.lambda", + "region": "us-east-1", + "memory_mb": 256, + }, + ) + def test_dict_hints(self): self.db.t.find(hint={"x": 1}) diff --git a/test/test_client.py b/test/test_client.py index be1994dd93..8e3d9c8b8b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1977,6 +1977,22 @@ def test_handshake_08_invalid_aws_ec2(self): None, ) + def test_handshake_09_container_with_provider(self): + self._test_handshake( + { + ENV_VAR_K8S: "1", + "AWS_LAMBDA_RUNTIME_API": "1", + "AWS_REGION": "us-east-1", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "256", + }, + { + "container": {"orchestrator": "kubernetes"}, + "name": "aws.lambda", + "region": "us-east-1", + "memory_mb": 256, + }, + ) + def test_dict_hints(self): self.db.t.find(hint={"x": 1}) From 710bc40c730d2fd982e1cb7a41fd91ac7b5d4498 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 12:12:18 -0400 Subject: [PATCH 15/35] =?UTF-8?q?PYTHON-4870=20-=20MongoClient.address=20s?= =?UTF-8?q?hould=20block=20until=20a=20connection=20suc=E2=80=A6=20(#1929)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pymongo/asynchronous/mongo_client.py | 7 ------- pymongo/synchronous/mongo_client.py | 7 ------- test/asynchronous/test_client.py | 2 -- test/test_client.py | 2 -- test/test_replica_set_reconfig.py | 3 ++- 5 files changed, 2 insertions(+), 19 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index bfae302dac..4e09efe401 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1453,13 +1453,6 @@ async def address(self) -> Optional[tuple[str, int]]: 'Cannot use "address" property when load balancing among' ' mongoses, use "nodes" instead.' ) - if topology_type not in ( - TOPOLOGY_TYPE.ReplicaSetWithPrimary, - TOPOLOGY_TYPE.Single, - TOPOLOGY_TYPE.LoadBalanced, - TOPOLOGY_TYPE.Sharded, - ): - return None return await self._server_property("address") @property diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1351cb200f..815446bb2c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1447,13 +1447,6 @@ def address(self) -> Optional[tuple[str, int]]: 'Cannot use "address" property when load balancing among' ' mongoses, use "nodes" instead.' 
) - if topology_type not in ( - TOPOLOGY_TYPE.ReplicaSetWithPrimary, - TOPOLOGY_TYPE.Single, - TOPOLOGY_TYPE.LoadBalanced, - TOPOLOGY_TYPE.Sharded, - ): - return None return self._server_property("address") @property diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index c6b6416c16..590154b857 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -838,8 +838,6 @@ async def test_init_disconnected(self): c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) - self.assertIsNone(await c.address) # PYTHON-2981 - await c.admin.command("ping") # connect if async_client_context.is_rs: # The primary's host and port are from the replica set config. self.assertIsNotNone(await c.address) diff --git a/test/test_client.py b/test/test_client.py index 8e3d9c8b8b..5bbb5bd751 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -812,8 +812,6 @@ def test_init_disconnected(self): c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) - self.assertIsNone(c.address) # PYTHON-2981 - c.admin.command("ping") # connect if client_context.is_rs: # The primary's host and port are from the replica set config. self.assertIsNotNone(c.address) diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 1dae0aea86..4c23d71b69 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -59,7 +59,8 @@ def test_client(self): with self.assertRaises(ServerSelectionTimeoutError): c.db.command("ping") - self.assertEqual(c.address, None) + with self.assertRaises(ServerSelectionTimeoutError): + _ = c.address # Client can still discover the primary node c.revive_host("a:1") From 82e673d6602b968823768d7c99bb5a676c00eb08 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 14:16:19 -0400 Subject: [PATCH 16/35] PYTHON-4870 - Update changelog for MongoClient.address fix (#1931) --- doc/changelog.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index e7b160b176..44d6fc9a57 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -14,6 +14,9 @@ PyMongo 4.11 brings a number of changes including: - Dropped support for MongoDB 3.6. - Added support for free-threaded Python with the GIL disabled. For more information see: `Free-threaded CPython `_. +- :attr:`~pymongo.asynchronous.mongo_client.AsyncMongoClient.address` and + :attr:`~pymongo.mongo_client.MongoClient.address` now correctly block when called on unconnected clients + until either connection succeeds or a server selection timeout error is raised. Issues Resolved ............... 
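
[A minimal sketch, not part of the patch series, illustrating the PYTHON-4870
change above: on an unconnected client, "address" now runs server selection
instead of returning None, so it either blocks until a server is selected or
raises ServerSelectionTimeoutError. The host and timeout values below are
illustrative placeholders, not taken from the patch.]

    from pymongo import MongoClient
    from pymongo.errors import ServerSelectionTimeoutError

    # connect=False defers all I/O, as in the test_init_disconnected tests.
    client = MongoClient(
        "mongodb://example.invalid:27017",  # placeholder, unreachable host
        connect=False,
        serverSelectionTimeoutMS=500,
    )
    try:
        # Formerly returned None on an unconnected client (PYTHON-2981); it
        # now blocks until server selection succeeds or the timeout elapses.
        print(client.address)
    except ServerSelectionTimeoutError:
        print("no server available within serverSelectionTimeoutMS")
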
From 1b6c0d3a2a7b82f9526f71d5583d11e7674d3c54 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 15 Oct 2024 13:33:04 -0500 Subject: [PATCH 17/35] PYTHON-4868 Generate server tests using shrub.py (#1930) --- .evergreen/config.yml | 548 ++++++++++++++++++++++---- .evergreen/hatch.sh | 4 +- .evergreen/run-tests.sh | 2 +- .evergreen/scripts/generate_config.py | 121 +++++- 4 files changed, 583 insertions(+), 92 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index dee4b608ec..c3427e66d0 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -409,6 +409,7 @@ functions: AUTH=${AUTH} \ SSL=${SSL} \ TEST_DATA_LAKE=${TEST_DATA_LAKE} \ + TEST_SUITES=${TEST_SUITES} \ MONGODB_API_VERSION=${MONGODB_API_VERSION} \ SKIP_HATCH=${SKIP_HATCH} \ bash ${PROJECT_DIRECTORY}/.evergreen/hatch.sh test:test-eg @@ -2399,6 +2400,470 @@ axes: batchtime: 10080 # 7 days buildvariants: +# Server Tests for RHEL8. +- name: test-rhel8-py3.9-auth-ssl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 py3.9 Auth SSL cov + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: ssl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-py3.9-noauth-ssl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 py3.9 NoAuth SSL cov + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: ssl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-py3.9-noauth-nossl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 py3.9 NoAuth NoSSL cov + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: nossl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-py3.13-auth-ssl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 py3.13 Auth SSL cov + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: ssl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-py3.13-noauth-ssl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 py3.13 NoAuth SSL cov + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: ssl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-py3.13-noauth-nossl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 py3.13 NoAuth NoSSL cov + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: nossl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-pypy3.10-auth-ssl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 pypy3.10 Auth SSL cov + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: ssl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-pypy3.10-noauth-ssl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 pypy3.10 NoAuth SSL cov + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: ssl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: 
[coverage_tag] +- name: test-rhel8-pypy3.10-noauth-nossl-cov + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Test RHEL8 pypy3.10 NoAuth NoSSL cov + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: nossl + COVERAGE: coverage + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: [coverage_tag] +- name: test-rhel8-py3.10-auth-ssl + tasks: + - name: .standalone + display_name: Test RHEL8 py3.10 Auth SSL + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: ssl + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: test-rhel8-py3.11-noauth-ssl + tasks: + - name: .replica_set + display_name: Test RHEL8 py3.11 NoAuth SSL + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: ssl + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: test-rhel8-py3.12-noauth-nossl + tasks: + - name: .sharded_cluster + display_name: Test RHEL8 py3.12 NoAuth NoSSL + run_on: + - rhel87-small + expansions: + AUTH: noauth + SSL: nossl + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: test-rhel8-pypy3.9-auth-ssl + tasks: + - name: .standalone + display_name: Test RHEL8 pypy3.9 Auth SSL + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: ssl + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 + +# Server tests for MacOS. +- name: test-macos-py3.9-auth-ssl-sync + tasks: + - name: .standalone + display_name: Test macOS py3.9 Auth SSL Sync + run_on: + - macos-14 + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-py3.9-auth-ssl-async + tasks: + - name: .standalone + display_name: Test macOS py3.9 Auth SSL Async + run_on: + - macos-14 + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-py3.13-noauth-ssl-sync + tasks: + - name: .replica_set + display_name: Test macOS py3.13 NoAuth SSL Sync + run_on: + - macos-14 + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-py3.13-noauth-ssl-async + tasks: + - name: .replica_set + display_name: Test macOS py3.13 NoAuth SSL Async + run_on: + - macos-14 + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-py3.9-noauth-nossl-sync + tasks: + - name: .sharded_cluster + display_name: Test macOS py3.9 NoAuth NoSSL Sync + run_on: + - macos-14 + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-py3.9-noauth-nossl-async + tasks: + - name: .sharded_cluster + display_name: Test macOS py3.9 NoAuth NoSSL Async + run_on: + - macos-14 + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default_async + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" + +# Server tests for macOS Arm64. 
+- name: test-macos-arm64-py3.9-auth-ssl-sync + tasks: + - name: .standalone + display_name: Test macOS Arm64 py3.9 Auth SSL Sync + run_on: + - macos-14-arm64 + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-arm64-py3.9-auth-ssl-async + tasks: + - name: .standalone + display_name: Test macOS Arm64 py3.9 Auth SSL Async + run_on: + - macos-14-arm64 + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-arm64-py3.13-noauth-ssl-sync + tasks: + - name: .replica_set + display_name: Test macOS Arm64 py3.13 NoAuth SSL Sync + run_on: + - macos-14-arm64 + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-arm64-py3.13-noauth-ssl-async + tasks: + - name: .replica_set + display_name: Test macOS Arm64 py3.13 NoAuth SSL Async + run_on: + - macos-14-arm64 + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-arm64-py3.9-noauth-nossl-sync + tasks: + - name: .sharded_cluster + display_name: Test macOS Arm64 py3.9 NoAuth NoSSL Sync + run_on: + - macos-14-arm64 + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" +- name: test-macos-arm64-py3.9-noauth-nossl-async + tasks: + - name: .sharded_cluster + display_name: Test macOS Arm64 py3.9 NoAuth NoSSL Async + run_on: + - macos-14-arm64 + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default_async + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + SKIP_CSOT_TESTS: "true" + +# Server tests for Windows. 
+- name: test-win64-py3.9-auth-ssl-sync + tasks: + - name: .standalone + display_name: Test Win64 py3.9 Auth SSL Sync + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: C:/python/Python39/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win64-py3.9-auth-ssl-async + tasks: + - name: .standalone + display_name: Test Win64 py3.9 Auth SSL Async + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: C:/python/Python39/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win64-py3.13-noauth-ssl-sync + tasks: + - name: .replica_set + display_name: Test Win64 py3.13 NoAuth SSL Sync + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: C:/python/Python313/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win64-py3.13-noauth-ssl-async + tasks: + - name: .replica_set + display_name: Test Win64 py3.13 NoAuth SSL Async + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: C:/python/Python313/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win64-py3.9-noauth-nossl-sync + tasks: + - name: .sharded_cluster + display_name: Test Win64 py3.9 NoAuth NoSSL Sync + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default + PYTHON_BINARY: C:/python/Python39/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win64-py3.9-noauth-nossl-async + tasks: + - name: .sharded_cluster + display_name: Test Win64 py3.9 NoAuth NoSSL Async + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default_async + PYTHON_BINARY: C:/python/Python39/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win32-py3.9-auth-ssl-sync + tasks: + - name: .standalone + display_name: Test Win32 py3.9 Auth SSL Sync + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: C:/python/32/Python39/python.exe + SKIP_CSOT_TESTS: "true" + +# Server tests for Win32. 
+- name: test-win32-py3.9-auth-ssl-async + tasks: + - name: .standalone + display_name: Test Win32 py3.9 Auth SSL Async + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: auth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: C:/python/32/Python39/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win32-py3.13-noauth-ssl-sync + tasks: + - name: .replica_set + display_name: Test Win32 py3.13 NoAuth SSL Sync + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default + PYTHON_BINARY: C:/python/32/Python313/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win32-py3.13-noauth-ssl-async + tasks: + - name: .replica_set + display_name: Test Win32 py3.13 NoAuth SSL Async + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: ssl + TEST_SUITES: default_async + PYTHON_BINARY: C:/python/32/Python313/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win32-py3.9-noauth-nossl-sync + tasks: + - name: .sharded_cluster + display_name: Test Win32 py3.9 NoAuth NoSSL Sync + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default + PYTHON_BINARY: C:/python/32/Python39/python.exe + SKIP_CSOT_TESTS: "true" +- name: test-win32-py3.9-noauth-nossl-async + tasks: + - name: .sharded_cluster + display_name: Test Win32 py3.9 NoAuth NoSSL Async + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: noauth + SSL: nossl + TEST_SUITES: default_async + PYTHON_BINARY: C:/python/32/Python39/python.exe + SKIP_CSOT_TESTS: "true" + - matrix_name: "tests-fips" matrix_spec: platform: @@ -2409,44 +2874,6 @@ buildvariants: tasks: - "test-fips-standalone" -- matrix_name: "test-macos" - matrix_spec: - platform: - # MacOS introduced SSL support with MongoDB >= 3.2. - # Older server versions (2.6, 3.0) are supported without SSL. - - macos - auth: "*" - ssl: "*" - exclude_spec: - # No point testing with SSL without auth. 
- - platform: macos - auth: "noauth" - ssl: "ssl" - display_name: "${platform} ${auth} ${ssl}" - tasks: - - ".latest" - - ".8.0" - - ".7.0" - - ".6.0" - - ".5.0" - - ".4.4" - - ".4.2" - - ".4.0" - -- matrix_name: "test-macos-arm64" - matrix_spec: - platform: - - macos-arm64 - auth-ssl: "*" - display_name: "${platform} ${auth-ssl}" - tasks: - - ".latest" - - ".8.0" - - ".7.0" - - ".6.0" - - ".5.0" - - ".4.4" - - matrix_name: "test-macos-encryption" matrix_spec: platform: @@ -2486,24 +2913,6 @@ buildvariants: tasks: - ".6.0" -- matrix_name: "tests-python-version-rhel8-test-ssl" - matrix_spec: - platform: rhel8 - python-version: "*" - auth-ssl: "*" - coverage: "*" - display_name: "${python-version} ${platform} ${auth-ssl} ${coverage}" - tasks: &all-server-versions - - ".rapid" - - ".latest" - - ".8.0" - - ".7.0" - - ".6.0" - - ".5.0" - - ".4.4" - - ".4.2" - - ".4.0" - - matrix_name: "tests-pyopenssl" matrix_spec: platform: rhel8 @@ -2580,7 +2989,16 @@ buildvariants: auth-ssl: "*" coverage: "*" display_name: "${c-extensions} ${python-version} ${platform} ${auth} ${ssl} ${coverage}" - tasks: *all-server-versions + tasks: &all-server-versions + - ".rapid" + - ".latest" + - ".8.0" + - ".7.0" + - ".6.0" + - ".5.0" + - ".4.4" + - ".4.2" + - ".4.0" - matrix_name: "tests-python-version-rhel8-compression" matrix_spec: @@ -2629,22 +3047,6 @@ buildvariants: display_name: "${green-framework} ${python-version} ${platform} ${auth-ssl}" tasks: *all-server-versions -- matrix_name: "tests-windows-python-version" - matrix_spec: - platform: windows - python-version-windows: "*" - auth-ssl: "*" - display_name: "${platform} ${python-version-windows} ${auth-ssl}" - tasks: *all-server-versions - -- matrix_name: "tests-windows-python-version-32-bit" - matrix_spec: - platform: windows - python-version-windows-32: "*" - auth-ssl: "*" - display_name: "${platform} ${python-version-windows-32} ${auth-ssl}" - tasks: *all-server-versions - - matrix_name: "tests-python-version-supports-openssl-102-test-ssl" matrix_spec: platform: rhel7 diff --git a/.evergreen/hatch.sh b/.evergreen/hatch.sh index 6f3d36b389..45d5113cd6 100644 --- a/.evergreen/hatch.sh +++ b/.evergreen/hatch.sh @@ -34,8 +34,8 @@ else # Set up virtualenv before installing hatch fi export HATCH_CONFIG hatch config restore - hatch config set dirs.data ".hatch/data" - hatch config set dirs.cache ".hatch/cache" + hatch config set dirs.data "$(pwd)/.hatch/data" + hatch config set dirs.cache "$(pwd)/.hatch/cache" run_hatch() { python -m hatch run "$@" diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 5e8429dd28..364570999f 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -30,7 +30,7 @@ set -o xtrace AUTH=${AUTH:-noauth} SSL=${SSL:-nossl} -TEST_SUITES="" +TEST_SUITES=${TEST_SUITES:-} TEST_ARGS="${*:1}" export PIP_QUIET=1 # Quiet by default diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index e98e527b72..044303ad8f 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -26,7 +26,17 @@ CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] PYPYS = ["pypy3.9", "pypy3.10"] ALL_PYTHONS = CPYTHONS + PYPYS +MIN_MAX_PYTHON = [CPYTHONS[0], CPYTHONS[-1]] BATCHTIME_WEEK = 10080 +AUTH_SSLS = [("auth", "ssl"), ("noauth", "ssl"), ("noauth", "nossl")] +TOPOLOGIES = ["standalone", "replica_set", "sharded_cluster"] +SYNCS = ["sync", "async"] +DISPLAY_LOOKUP = dict( + ssl=dict(ssl="SSL", nossl="NoSSL"), + auth=dict(auth="Auth", noauth="NoAuth"), + 
test_suites=dict(default="Sync", default_async="Async"), + coverage=dict(coverage="cov"), +) HOSTS = dict() @@ -35,11 +45,18 @@ class Host: name: str run_on: str display_name: str + expansions: dict[str, str] -HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8") -HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64") -HOSTS["macos"] = Host("macos", "macos-14", "macOS") +_macos_expansions = dict( # CSOT tests are unreliable on slow hosts. + SKIP_CSOT_TESTS="true" +) + +HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8", dict()) +HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64", _macos_expansions) +HOSTS["win32"] = Host("win32", "windows-64-vsMulti-small", "Win32", _macos_expansions) +HOSTS["macos"] = Host("macos", "macos-14", "macOS", _macos_expansions) +HOSTS["macos-arm64"] = Host("macos-arm64", "macos-14-arm64", "macOS Arm64", _macos_expansions) ############## @@ -67,6 +84,7 @@ def create_variant( expansions["PYTHON_BINARY"] = get_python_binary(python, host) if version: expansions["VERSION"] = version + expansions.update(HOSTS[host].expansions) expansions = expansions or None return BuildVariant( name=name, @@ -80,10 +98,8 @@ def create_variant( def get_python_binary(python: str, host: str) -> str: """Get the appropriate python binary given a python version and host.""" - if host == "win64": - is_32 = python.startswith("32-bit") - if is_32: - _, python = python.split() + if host in ["win64", "win32"]: + if host == "win32": base = "C:/python/32" else: base = "C:/python" @@ -93,19 +109,29 @@ def get_python_binary(python: str, host: str) -> str: if host == "rhel8": return f"/opt/python/{python}/bin/python3" - if host == "macos": + if host in ["macos", "macos-arm64"]: return f"/Library/Frameworks/Python.Framework/Versions/{python}/bin/python3" raise ValueError(f"no match found for python {python} on {host}") -def get_display_name(base: str, host: str, version: str, python: str) -> str: +def get_display_name(base: str, host: str, **kwargs) -> str: """Get the display name of a variant.""" - if version not in ["rapid", "latest"]: - version = f"v{version}" - if not python.startswith("pypy"): - python = f"py{python}" - return f"{base} {HOSTS[host].display_name} {version} {python}" + display_name = f"{base} {HOSTS[host].display_name}" + for key, value in kwargs.items(): + name = value + if key == "version": + if value not in ["rapid", "latest"]: + name = f"v{value}" + elif key == "python": + if not value.startswith("pypy"): + name = f"py{value}" + elif key.lower() in DISPLAY_LOOKUP: + name = DISPLAY_LOOKUP[key.lower()][value] + else: + raise ValueError(f"Missing display handling for {key}") + display_name = f"{display_name} {name}" + return display_name def zip_cycle(*iterables, empty_default=None): @@ -115,6 +141,15 @@ def zip_cycle(*iterables, empty_default=None): yield tuple(next(i, empty_default) for i in cycles) +def generate_yaml(tasks=None, variants=None): + """Generate the yaml for a given set of tasks and variants.""" + project = EvgProject(tasks=tasks, buildvariants=variants) + out = ShrubService.generate_yaml(project) + # Dedent by two spaces to match what we use in config.yml + lines = [line[2:] for line in out.splitlines()] + print("\n".join(lines)) # noqa: T201 + + ############## # Variants ############## @@ -159,9 +194,63 @@ def create_ocsp_variants() -> list[BuildVariant]: return variants +def create_server_variants() -> list[BuildVariant]: + variants = [] + + # Run the full matrix on linux with min and max CPython, and latest pypy. 
+ host = "rhel8" + for python, (auth, ssl) in product([*MIN_MAX_PYTHON, PYPYS[-1]], AUTH_SSLS): + display_name = f"Test {host}" + expansions = dict(AUTH=auth, SSL=ssl, COVERAGE="coverage") + display_name = get_display_name("Test", host, python=python, **expansions) + variant = create_variant( + [f".{t}" for t in TOPOLOGIES], + display_name, + python=python, + host=host, + tags=["coverage_tag"], + expansions=expansions, + ) + variants.append(variant) + + # Test the rest of the pythons on linux. + for python, (auth, ssl), topology in zip_cycle( + CPYTHONS[1:-1] + PYPYS[:-1], AUTH_SSLS, TOPOLOGIES + ): + display_name = f"Test {host}" + expansions = dict(AUTH=auth, SSL=ssl) + display_name = get_display_name("Test", host, python=python, **expansions) + variant = create_variant( + [f".{topology}"], + display_name, + python=python, + host=host, + expansions=expansions, + ) + variants.append(variant) + + # Test a subset on each of the other platforms. + for host in ("macos", "macos-arm64", "win64", "win32"): + for (python, (auth, ssl), topology), sync in product( + zip_cycle(MIN_MAX_PYTHON, AUTH_SSLS, TOPOLOGIES), SYNCS + ): + test_suite = "default" if sync == "sync" else "default_async" + expansions = dict(AUTH=auth, SSL=ssl, TEST_SUITES=test_suite) + display_name = get_display_name("Test", host, python=python, **expansions) + variant = create_variant( + [f".{topology}"], + display_name, + python=python, + host=host, + expansions=expansions, + ) + variants.append(variant) + + return variants + + ################## # Generate Config ################## -project = EvgProject(tasks=None, buildvariants=create_ocsp_variants()) -print(ShrubService.generate_yaml(project)) # noqa: T201 +generate_yaml(variants=create_server_variants()) From 3855effbd844a4c48ca2d13f651ce6dd908b14a3 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 15:16:42 -0400 Subject: [PATCH 18/35] PYTHON-4842 - Convert test.test_create_entities to async (#1919) --- test/asynchronous/test_create_entities.py | 128 ++++++++++++++++++++++ test/test_create_entities.py | 2 + tools/synchro.py | 1 + 3 files changed, 131 insertions(+) create mode 100644 test/asynchronous/test_create_entities.py diff --git a/test/asynchronous/test_create_entities.py b/test/asynchronous/test_create_entities.py new file mode 100644 index 0000000000..cb2ec63f4c --- /dev/null +++ b/test/asynchronous/test_create_entities.py @@ -0,0 +1,128 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest +from test.asynchronous.unified_format import UnifiedSpecTestMixinV1 + +_IS_SYNC = False + + +class TestCreateEntities(AsyncIntegrationTest): + async def test_store_events_as_entities(self): + self.scenario_runner = UnifiedSpecTestMixinV1() + spec = { + "description": "blank", + "schemaVersion": "1.2", + "createEntities": [ + { + "client": { + "id": "client0", + "storeEventsAsEntities": [ + { + "id": "events1", + "events": [ + "PoolCreatedEvent", + ], + } + ], + } + }, + ], + "tests": [{"description": "foo", "operations": []}], + } + self.scenario_runner.TEST_SPEC = spec + await self.scenario_runner.asyncSetUp() + await self.scenario_runner.run_scenario(spec["tests"][0]) + await self.scenario_runner.entity_map["client0"].close() + final_entity_map = self.scenario_runner.entity_map + self.assertIn("events1", final_entity_map) + self.assertGreater(len(final_entity_map["events1"]), 0) + for event in final_entity_map["events1"]: + self.assertIn("PoolCreatedEvent", event["name"]) + + async def test_store_all_others_as_entities(self): + self.scenario_runner = UnifiedSpecTestMixinV1() + spec = { + "description": "Find", + "schemaVersion": "1.2", + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": {"retryReads": True}, + } + }, + {"database": {"id": "database0", "client": "client0", "databaseName": "dat"}}, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "dat", + } + }, + ], + "tests": [ + { + "description": "test loops", + "operations": [ + { + "name": "loop", + "object": "testRunner", + "arguments": { + "storeIterationsAsEntity": "iterations", + "storeSuccessesAsEntity": "successes", + "storeFailuresAsEntity": "failures", + "storeErrorsAsEntity": "errors", + "numIterations": 5, + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": {"document": {"_id": 1, "x": 44}}, + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": {"document": {"_id": 2, "x": 44}}, + }, + ], + }, + } + ], + } + ], + } + + await self.client.dat.dat.delete_many({}) + self.scenario_runner.TEST_SPEC = spec + await self.scenario_runner.asyncSetUp() + await self.scenario_runner.run_scenario(spec["tests"][0]) + await self.scenario_runner.entity_map["client0"].close() + entity_map = self.scenario_runner.entity_map + self.assertEqual(len(entity_map["errors"]), 4) + for error in entity_map["errors"]: + self.assertEqual(error["type"], "DuplicateKeyError") + self.assertEqual(entity_map["failures"], []) + self.assertEqual(entity_map["successes"], 2) + self.assertEqual(entity_map["iterations"], 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_create_entities.py b/test/test_create_entities.py index b7965d4a1d..ad75fe5702 100644 --- a/test/test_create_entities.py +++ b/test/test_create_entities.py @@ -21,6 +21,8 @@ from test import IntegrationTest from test.unified_format import UnifiedSpecTestMixinV1 +_IS_SYNC = True + class TestCreateEntities(IntegrationTest): def test_store_events_as_entities(self): diff --git a/tools/synchro.py b/tools/synchro.py index 25f506ed5a..2123a66616 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -199,6 +199,7 @@ def async_only_test(f: str) -> bool: "test_common.py", "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", + "test_create_entities.py", "test_crud_unified.py", 
"test_cursor.py", "test_database.py", From fa263dc87dfe13f1c2de14ab86c67871ed3b24fb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 15:48:05 -0400 Subject: [PATCH 19/35] PYTHON-4847 - Convert test.test_collection_management.py to async (#1916) --- .../test_collection_management.py | 41 +++++++++++++++++++ test/asynchronous/unified_format.py | 2 +- test/test_collection_management.py | 12 +++++- test/unified_format.py | 2 +- tools/synchro.py | 1 + 5 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 test/asynchronous/test_collection_management.py diff --git a/test/asynchronous/test_collection_management.py b/test/asynchronous/test_collection_management.py new file mode 100644 index 0000000000..c0edf91581 --- /dev/null +++ b/test/asynchronous/test_collection_management.py @@ -0,0 +1,41 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the collection management unified spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "collection_management") +else: + _TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "collection_management" + ) + +# Generate unified tests. +globals().update(generate_test_classes(_TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 42bda59cb2..8f32ac4a2e 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -773,7 +773,7 @@ async def _databaseOperation_listCollections(self, target, *args, **kwargs): if "batch_size" in kwargs: kwargs["cursor"] = {"batchSize": kwargs.pop("batch_size")} cursor = await target.list_collections(*args, **kwargs) - return list(cursor) + return await cursor.to_list() async def _databaseOperation_createCollection(self, target, *args, **kwargs): # PYTHON-1936 Ignore the listCollections event from create_collection. diff --git a/test/test_collection_management.py b/test/test_collection_management.py index 0eacde1302..063c20df8f 100644 --- a/test/test_collection_management.py +++ b/test/test_collection_management.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,11 +24,18 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. 
-TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "collection_management") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "collection_management") +else: + _TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "collection_management" + ) # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(_TEST_PATH, module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/unified_format.py b/test/unified_format.py index 13ab0af69b..be7fc1f8ad 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -769,7 +769,7 @@ def _databaseOperation_listCollections(self, target, *args, **kwargs): if "batch_size" in kwargs: kwargs["cursor"] = {"batchSize": kwargs.pop("batch_size")} cursor = target.list_collections(*args, **kwargs) - return list(cursor) + return cursor.to_list() def _databaseOperation_createCollection(self, target, *args, **kwargs): # PYTHON-1936 Ignore the listCollections event from create_collection. diff --git a/tools/synchro.py b/tools/synchro.py index 2123a66616..0a7109c6d4 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -192,6 +192,7 @@ def async_only_test(f: str) -> bool: "test_client_context.py", "test_collation.py", "test_collection.py", + "test_collection_management.py", "test_command_logging.py", "test_command_logging.py", "test_command_monitoring.py", From 8034baec90043c1d3cf4dc17a5481a559743c524 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Tue, 15 Oct 2024 18:45:49 -0400 Subject: [PATCH 20/35] PYTHON-4834 Add __repr__ to IndexModel, SearchIndexModel (#1909) --- doc/changelog.rst | 2 ++ pymongo/operations.py | 13 +++++++ test/test_operations.py | 80 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 test/test_operations.py diff --git a/doc/changelog.rst b/doc/changelog.rst index 44d6fc9a57..3935fa3492 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -17,6 +17,8 @@ PyMongo 4.11 brings a number of changes including: - :attr:`~pymongo.asynchronous.mongo_client.AsyncMongoClient.address` and :attr:`~pymongo.mongo_client.MongoClient.address` now correctly block when called on unconnected clients until either connection succeeds or a server selection timeout error is raised. +- Added :func:`repr` support to :class:`pymongo.operations.IndexModel`. +- Added :func:`repr` support to :class:`pymongo.operations.SearchIndexModel`. Issues Resolved ............... 
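In use, the new representations round-trip through ``eval`` — a doctest-style
sketch (the expected output strings are exactly those asserted by the new
test_operations.py added below)::

    >>> from pymongo.operations import IndexModel, SearchIndexModel
    >>> IndexModel("hello")
    IndexModel({'hello': 1}, name='hello_1')
    >>> IndexModel({"hello": 1, "world": -1})
    IndexModel({'hello': 1, 'world': -1}, name='hello_1_world_-1')
    >>> SearchIndexModel({"hello": "hello"}, key=1)
    SearchIndexModel(definition={'hello': 'hello'}, key=1)
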
diff --git a/pymongo/operations.py b/pymongo/operations.py index d2e1feba69..384ffc94be 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -773,6 +773,13 @@ def document(self) -> dict[str, Any]: """ return self.__document + def __repr__(self) -> str: + return "{}({}{})".format( + self.__class__.__name__, + self.document["key"], + "".join([f", {key}={value!r}" for key, value in self.document.items() if key != "key"]), + ) + class SearchIndexModel: """Represents a search index to create.""" @@ -812,3 +819,9 @@ def __init__( def document(self) -> Mapping[str, Any]: """The document for this index.""" return self.__document + + def __repr__(self) -> str: + return "{}({})".format( + self.__class__.__name__, + ", ".join([f"{key}={value!r}" for key, value in self.document.items()]), + ) diff --git a/test/test_operations.py b/test/test_operations.py new file mode 100644 index 0000000000..3ee6677735 --- /dev/null +++ b/test/test_operations.py @@ -0,0 +1,80 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the operations module.""" +from __future__ import annotations + +from test import UnitTest, unittest + +from pymongo import ASCENDING, DESCENDING +from pymongo.collation import Collation +from pymongo.errors import OperationFailure +from pymongo.operations import IndexModel, SearchIndexModel + + +class TestOperationsBase(UnitTest): + """Base class for testing operations module.""" + + def assertRepr(self, obj): + new_obj = eval(repr(obj)) + self.assertEqual(type(new_obj), type(obj)) + self.assertEqual(repr(new_obj), repr(obj)) + + +class TestIndexModel(TestOperationsBase): + """Test IndexModel features.""" + + def test_repr(self): + # Based on examples in test_collection.py + self.assertRepr(IndexModel("hello")) + self.assertRepr(IndexModel([("hello", DESCENDING), ("world", ASCENDING)])) + self.assertRepr( + IndexModel([("hello", DESCENDING), ("world", ASCENDING)], name="hello_world") + ) + # Test all the kwargs + self.assertRepr(IndexModel("name", name="name")) + self.assertRepr(IndexModel("unique", unique=False)) + self.assertRepr(IndexModel("background", background=True)) + self.assertRepr(IndexModel("sparse", sparse=True)) + self.assertRepr(IndexModel("bucketSize", bucketSize=1)) + self.assertRepr(IndexModel("min", min=1)) + self.assertRepr(IndexModel("max", max=1)) + self.assertRepr(IndexModel("expireAfterSeconds", expireAfterSeconds=1)) + self.assertRepr( + IndexModel("partialFilterExpression", partialFilterExpression={"hello": "world"}) + ) + self.assertRepr(IndexModel("collation", collation=Collation(locale="en_US"))) + self.assertRepr(IndexModel("wildcardProjection", wildcardProjection={"$**": 1})) + self.assertRepr(IndexModel("hidden", hidden=False)) + # Test string literal + self.assertEqual(repr(IndexModel("hello")), "IndexModel({'hello': 1}, name='hello_1')") + self.assertEqual( + repr(IndexModel({"hello": 1, "world": -1})), + "IndexModel({'hello': 1, 'world': -1}, name='hello_1_world_-1')", + ) + + +class 
TestSearchIndexModel(TestOperationsBase): + """Test SearchIndexModel features.""" + + def test_repr(self): + self.assertRepr(SearchIndexModel({"hello": "hello"}, key=1)) + self.assertEqual( + repr(SearchIndexModel({"hello": "hello"}, key=1)), + "SearchIndexModel(definition={'hello': 'hello'}, key=1)", + ) + + +if __name__ == "__main__": + unittest.main() From 463518bf8136264fbed34e3c3ddbc0a34d109156 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Wed, 16 Oct 2024 11:02:57 -0400 Subject: [PATCH 21/35] PYTHON-4765 Resync server-selection spec (#1935) --- .../operation-id.json | 4 +- .../server_selection_logging/replica-set.json | 2 +- test/server_selection_logging/sharded.json | 2 +- test/server_selection_logging/standalone.json | 930 +----------------- 4 files changed, 6 insertions(+), 932 deletions(-) diff --git a/test/server_selection_logging/operation-id.json b/test/server_selection_logging/operation-id.json index ccc2623166..72ebff60d8 100644 --- a/test/server_selection_logging/operation-id.json +++ b/test/server_selection_logging/operation-id.json @@ -197,7 +197,7 @@ } }, { - "level": "debug", + "level": "info", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", @@ -383,7 +383,7 @@ } }, { - "level": "debug", + "level": "info", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", diff --git a/test/server_selection_logging/replica-set.json b/test/server_selection_logging/replica-set.json index 830b1ea51a..5eba784bf2 100644 --- a/test/server_selection_logging/replica-set.json +++ b/test/server_selection_logging/replica-set.json @@ -184,7 +184,7 @@ } }, { - "level": "debug", + "level": "info", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", diff --git a/test/server_selection_logging/sharded.json b/test/server_selection_logging/sharded.json index 346c050f9e..d42fba9100 100644 --- a/test/server_selection_logging/sharded.json +++ b/test/server_selection_logging/sharded.json @@ -193,7 +193,7 @@ } }, { - "level": "debug", + "level": "info", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", diff --git a/test/server_selection_logging/standalone.json b/test/server_selection_logging/standalone.json index 3152d0bbf3..3b3eddd841 100644 --- a/test/server_selection_logging/standalone.json +++ b/test/server_selection_logging/standalone.json @@ -47,29 +47,9 @@ } } ], - "initialData": [ - { - "collectionName": "server-selection", - "databaseName": "logging-tests", - "documents": [ - { - "_id": 1, - "x": 11 - }, - { - "_id": 2, - "x": 22 - }, - { - "_id": 3, - "x": 33 - } - ] - } - ], "tests": [ { - "description": "A successful insert operation", + "description": "A successful operation", "operations": [ { "name": "waitForEvent", @@ -211,7 +191,7 @@ } }, { - "level": "debug", + "level": "info", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", @@ -250,912 +230,6 @@ ] } ] - }, - { - "description": "A successful find operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "findOne", - "object": "collection", - "arguments": { - "filter": { - "x": 1 - } - } - } - - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": 
"serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "find", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "find", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "A successful findAndModify operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "findOneAndReplace", - "object": "collection", - "arguments": { - "filter": { - "x": 1 - }, - "replacement": { - "x": 11 - } - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "findAndModify", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "findAndModify", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "A successful find and getMore operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "find", - "object": "collection", - "arguments": { - "batchSize": 3 - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "find", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "find", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "getMore", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "getMore", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "A successful aggregate operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "aggregate", - "object": "collection", - "arguments": { - "pipeline": [ - { - 
"$match": { - "_id": { - "$gt": 1 - } - } - } - ] - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "aggregate", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "aggregate", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "A successful count operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "countDocuments", - "object": "collection", - "arguments": { - "filter": {} - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "count", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "count", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "A successful distinct operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "distinct", - "object": "collection", - "arguments": { - "fieldName": "x", - "filter": {} - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "distinct", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "distinct", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "Successful collection management operations", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "createCollection", - "object": "database", - "arguments": { - "collection": "foo" - } - }, - { - "name": "listCollections", - "object": "database" - }, - { - "name": "dropCollection", - "object": "database", - "arguments": { - "collection": "foo" - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "create", - 
"topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "create", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "listCollections", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "listCollections", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "drop", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "drop", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "Successful index operations", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "createIndex", - "object": "collection", - "arguments": { - "keys": { - "x": 1 - }, - "name": "x_1" - } - }, - { - "name": "listIndexes", - "object": "collection" - }, - { - "name": "dropIndex", - "object": "collection", - "arguments": { - "name": "x_1" - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "createIndexes", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "createIndexes", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "listIndexes", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "listIndexes", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "dropIndexes", - "topologyDescription": { - 
"$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "dropIndexes", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "A successful update operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "updateOne", - "object": "collection", - "arguments": { - "filter": { - "x": 1 - }, - "update": { - "$inc": { - "x": 1 - } - } - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "update", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "update", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] - }, - { - "description": "A successful delete operation", - "operations": [ - { - "name": "waitForEvent", - "object": "testRunner", - "arguments": { - "client": "client", - "event": { - "topologyDescriptionChangedEvent": {} - }, - "count": 2 - } - }, - { - "name": "deleteOne", - "object": "collection", - "arguments": { - "filter": { - "x": 1 - } - } - } - ], - "expectLogMessages": [ - { - "client": "client", - "messages": [ - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection started", - "selector": { - "$$exists": true - }, - "operation": "delete", - "topologyDescription": { - "$$exists": true - } - } - }, - { - "level": "debug", - "component": "serverSelection", - "data": { - "message": "Server selection succeeded", - "selector": { - "$$exists": true - }, - "operation": "delete", - "topologyDescription": { - "$$exists": true - }, - "serverHost": { - "$$type": "string" - }, - "serverPort": { - "$$type": [ - "int", - "long" - ] - } - } - } - ] - } - ] } ] } From d1375d4178822c376ce3beb0f5987dd7894a03aa Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 16 Oct 2024 13:41:35 -0500 Subject: [PATCH 22/35] PYTHON-4865 Skip test_write_concern_failure tests temporarily (#1936) --- test/asynchronous/test_bulk.py | 6 ++++++ test/test_bulk.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 42a3311072..c9ff167b43 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -971,6 +971,9 @@ async def cause_wtimeout(self, requests, ordered): @async_client_context.require_replica_set @async_client_context.require_secondaries_count(1) async def test_write_concern_failure_ordered(self): + self.skipTest("Skipping until PYTHON-4865 is resolved.") + details = None + # Ensure we don't raise on wnote. 
coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w)) result = await coll_ww.bulk_write([DeleteOne({"something": "that does no exist"})]) @@ -1051,6 +1054,9 @@ async def test_write_concern_failure_ordered(self): @async_client_context.require_replica_set @async_client_context.require_secondaries_count(1) async def test_write_concern_failure_unordered(self): + self.skipTest("Skipping until PYTHON-4865 is resolved.") + details = None + # Ensure we don't raise on wnote. coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w)) result = await coll_ww.bulk_write( diff --git a/test/test_bulk.py b/test/test_bulk.py index 64fd48e8cd..ea2b803804 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -969,6 +969,9 @@ def cause_wtimeout(self, requests, ordered): @client_context.require_replica_set @client_context.require_secondaries_count(1) def test_write_concern_failure_ordered(self): + self.skipTest("Skipping until PYTHON-4865 is resolved.") + details = None + # Ensure we don't raise on wnote. coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w)) result = coll_ww.bulk_write([DeleteOne({"something": "that does no exist"})]) @@ -1049,6 +1052,9 @@ def test_write_concern_failure_ordered(self): @client_context.require_replica_set @client_context.require_secondaries_count(1) def test_write_concern_failure_unordered(self): + self.skipTest("Skipping until PYTHON-4865 is resolved.") + details = None + # Ensure we don't raise on wnote. coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w)) result = coll_ww.bulk_write([DeleteOne({"something": "that does no exist"})], ordered=False) From 29064f5b1d85cbea872a6c37023e3d5fa25b9a3d Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 16 Oct 2024 12:15:48 -0700 Subject: [PATCH 23/35] PYTHON-4873 Remove bson-stdint-win32.h from THIRD-PARTY-NOTICES (#1937) --- THIRD-PARTY-NOTICES | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES index 0b9fc738ed..55b8ff7078 100644 --- a/THIRD-PARTY-NOTICES +++ b/THIRD-PARTY-NOTICES @@ -38,36 +38,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -2) License Notice for bson-stdint-win32.h ------------------------------------------ - -ISO C9x compliant stdint.h for Microsoft Visual Studio -Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124 - - Copyright (c) 2006-2013 Alexander Chemeris - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the product nor the names of its contributors may - be used to endorse or promote products derived from this software - without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED -WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO
-EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
-WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
-OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
-ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

From 6f4258c1cdb95f6fe624a66760a66423048b6884 Mon Sep 17 00:00:00 2001
From: Steven Silvester
Date: Wed, 16 Oct 2024 16:41:14 -0500
Subject: [PATCH 24/35] PYTHON-4576 Allow update to supply sort option (#1881)

---
 doc/changelog.rst | 4 +
 pymongo/asynchronous/bulk.py | 13 +
 pymongo/asynchronous/client_bulk.py | 9 +
 pymongo/asynchronous/collection.py | 25 +
 pymongo/operations.py | 44 +-
 pymongo/synchronous/bulk.py | 13 +
 pymongo/synchronous/client_bulk.py | 9 +
 pymongo/synchronous/collection.py | 25 +
 .../aggregate-write-readPreference.json | 69 ---
 .../unified/bulkWrite-replaceOne-sort.json | 239 ++++++++
 .../unified/bulkWrite-updateOne-sort.json | 255 +++++++++
 .../client-bulkWrite-partialResults.json | 540 ++++++++++++++++++
 .../client-bulkWrite-replaceOne-sort.json | 162 ++++++
 .../client-bulkWrite-updateOne-sort.json | 166 ++++++
 .../db-aggregate-write-readPreference.json | 51 --
 test/crud/unified/replaceOne-sort.json | 232 ++++++++
 test/crud/unified/updateOne-sort.json | 240 ++++++++
 test/utils.py | 4 -
 18 files changed, 1967 insertions(+), 133 deletions(-)
 create mode 100644 test/crud/unified/bulkWrite-replaceOne-sort.json
 create mode 100644 test/crud/unified/bulkWrite-updateOne-sort.json
 create mode 100644 test/crud/unified/client-bulkWrite-partialResults.json
 create mode 100644 test/crud/unified/client-bulkWrite-replaceOne-sort.json
 create mode 100644 test/crud/unified/client-bulkWrite-updateOne-sort.json
 create mode 100644 test/crud/unified/replaceOne-sort.json
 create mode 100644 test/crud/unified/updateOne-sort.json

diff --git a/doc/changelog.rst b/doc/changelog.rst
index 3935fa3492..4c1955d19d 100644
--- a/doc/changelog.rst
+++ b/doc/changelog.rst
@@ -19,6 +19,10 @@ PyMongo 4.11 brings a number of changes including:
    until either connection succeeds or a server selection timeout error is raised.
 - Added :func:`repr` support to :class:`pymongo.operations.IndexModel`.
 - Added :func:`repr` support to :class:`pymongo.operations.SearchIndexModel`.
+- Added ``sort`` parameter to
+  :meth:`~pymongo.collection.Collection.update_one`, :meth:`~pymongo.collection.Collection.replace_one`,
+  :class:`~pymongo.operations.UpdateOne`, and
+  :class:`~pymongo.operations.ReplaceOne`.

 Issues Resolved
 ...............
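A minimal sketch of the new option in use (the filter and sort values mirror
the unified tests added below; the client URI and namespace here are
illustrative, and per the new docstrings the option requires MongoDB 8.0+)::

    from pymongo import MongoClient, UpdateOne

    coll = MongoClient().test.coll  # hypothetical deployment/namespace
    # sort controls which matching document gets updated: here, the
    # highest-_id document among those matching the filter.
    coll.update_one({"_id": {"$gt": 1}}, {"$set": {"x": 1}}, sort={"_id": -1})
    coll.replace_one({"_id": {"$gt": 1}}, {"x": 1}, sort={"_id": -1})
    # The bulk-write model accepts the same keyword.
    coll.bulk_write([UpdateOne({"_id": {"$gt": 1}}, {"$set": {"x": 1}}, sort={"_id": -1})])
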
diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 9d33a990ed..e6cfe5b36e 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -109,6 +109,7 @@ def __init__( self.uses_array_filters = False self.uses_hint_update = False self.uses_hint_delete = False + self.uses_sort = False self.is_retryable = True self.retrying = False self.started_retryable_write = False @@ -144,6 +145,7 @@ def add_update( collation: Optional[Mapping[str, Any]] = None, array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) @@ -159,6 +161,9 @@ def add_update( if hint is not None: self.uses_hint_update = True cmd["hint"] = hint + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort if multi: # A bulk_write containing an update_many is not retryable. self.is_retryable = False @@ -171,6 +176,7 @@ def add_replace( upsert: bool = False, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) @@ -181,6 +187,9 @@ def add_replace( if hint is not None: self.uses_hint_update = True cmd["hint"] = hint + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort self.ops.append((_UPDATE, cmd)) def add_delete( @@ -699,6 +708,10 @@ async def execute_no_results( raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) + if unack and self.uses_sort and conn.max_wire_version < 25: + raise ConfigurationError( + "Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands." + ) # Cannot have both unacknowledged writes and bypass document validation. if self.bypass_doc_val: raise OperationFailure( diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index dc800c9549..96571c21eb 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -118,6 +118,7 @@ def __init__( self.uses_array_filters = False self.uses_hint_update = False self.uses_hint_delete = False + self.uses_sort = False self.is_retryable = self.client.options.retry_writes self.retrying = False @@ -148,6 +149,7 @@ def add_update( collation: Optional[Mapping[str, Any]] = None, array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) @@ -169,6 +171,9 @@ def add_update( if collation is not None: self.uses_collation = True cmd["collation"] = collation + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort if multi: # A bulk_write containing an update_many is not retryable. 
self.is_retryable = False @@ -184,6 +189,7 @@ def add_replace( upsert: Optional[bool] = None, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) @@ -202,6 +208,9 @@ def add_replace( if collation is not None: self.uses_collation = True cmd["collation"] = collation + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort self.ops.append(("replace", cmd)) self.namespaces.append(namespace) self.total_ops += 1 diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 4ddcbab4d2..9b73423627 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -993,6 +993,7 @@ async def _update( session: Optional[AsyncClientSession] = None, retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: """Internal update / replace helper.""" @@ -1024,6 +1025,14 @@ async def _update( if not isinstance(hint, str): hint = helpers_shared._index_document(hint) update_doc["hint"] = hint + if sort is not None: + if not acknowledged and conn.max_wire_version < 25: + raise ConfigurationError( + "Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands." + ) + common.validate_is_mapping("sort", sort) + update_doc["sort"] = sort + command = {"update": self.name, "ordered": ordered, "updates": [update_doc]} if let is not None: common.validate_is_mapping("let", let) @@ -1079,6 +1088,7 @@ async def _update_retryable( hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: """Internal update / replace helper.""" @@ -1102,6 +1112,7 @@ async def _update( session=session, retryable_write=retryable_write, let=let, + sort=sort, comment=comment, ) @@ -1122,6 +1133,7 @@ async def replace_one( hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: """Replace a single document matching the filter. @@ -1176,8 +1188,13 @@ async def replace_one( aggregate expression context (e.g. "$$var"). :param comment: A user-provided comment to attach to this command. + :param sort: Specify which document the operation updates if the query matches + multiple documents. The first document matched by the sort order will be updated. + This option is only supported on MongoDB 8.0 and above. :return: - An instance of :class:`~pymongo.results.UpdateResult`. + .. versionchanged:: 4.11 + Added ``sort`` parameter. .. versionchanged:: 4.1 Added ``let`` parameter. Added ``comment`` parameter. @@ -1209,6 +1226,7 @@ async def replace_one( hint=hint, session=session, let=let, + sort=sort, comment=comment, ), write_concern.acknowledged, @@ -1225,6 +1243,7 @@ async def update_one( hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: """Update a single document matching the filter. 
@@ -1283,11 +1302,16 @@ async def update_one( constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). + :param sort: Specify which document the operation updates if the query matches + multiple documents. The first document matched by the sort order will be updated. + This option is only supported on MongoDB 8.0 and above. :param comment: A user-provided comment to attach to this command. :return: - An instance of :class:`~pymongo.results.UpdateResult`. + .. versionchanged:: 4.11 + Added ``sort`` parameter. .. versionchanged:: 4.1 Added ``let`` parameter. Added ``comment`` parameter. @@ -1322,6 +1346,7 @@ async def update_one( hint=hint, session=session, let=let, + sort=sort, comment=comment, ), write_concern.acknowledged, diff --git a/pymongo/operations.py b/pymongo/operations.py index 384ffc94be..8905048c4e 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -325,6 +325,7 @@ class ReplaceOne(Generic[_DocumentType]): "_collation", "_hint", "_namespace", + "_sort", ) def __init__( @@ -335,6 +336,7 @@ def __init__( collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, namespace: Optional[str] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create a ReplaceOne instance. @@ -353,8 +355,12 @@ def __init__( :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. + :param sort: Specify which document the operation updates if the query matches + multiple documents. The first document matched by the sort order will be updated. :param namespace: (optional) The namespace in which to replace a document. + .. versionchanged:: 4.10 + Added ``sort`` option. .. versionchanged:: 4.9 Added the `namespace` option to support `MongoClient.bulk_write`. .. 
versionchanged:: 3.11 @@ -371,6 +377,7 @@ def __init__( else: self._hint = hint + self._sort = sort self._filter = filter self._doc = replacement self._upsert = upsert @@ -385,6 +392,7 @@ def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: self._upsert, collation=validate_collation_or_none(self._collation), hint=self._hint, + sort=self._sort, ) def _add_to_client_bulk(self, bulkobj: _AgnosticClientBulk) -> None: @@ -400,6 +408,7 @@ def _add_to_client_bulk(self, bulkobj: _AgnosticClientBulk) -> None: self._upsert, collation=validate_collation_or_none(self._collation), hint=self._hint, + sort=self._sort, ) def __eq__(self, other: Any) -> bool: @@ -411,13 +420,15 @@ def __eq__(self, other: Any) -> bool: other._collation, other._hint, other._namespace, + other._sort, ) == ( self._filter, self._doc, self._upsert, self._collation, - other._hint, + self._hint, self._namespace, + self._sort, ) return NotImplemented @@ -426,7 +437,7 @@ def __ne__(self, other: Any) -> bool: def __repr__(self) -> str: if self._namespace: - return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( + return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( self.__class__.__name__, self._filter, self._doc, @@ -434,14 +445,16 @@ def __repr__(self) -> str: self._collation, self._hint, self._namespace, + self._sort, ) - return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( + return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( self.__class__.__name__, self._filter, self._doc, self._upsert, self._collation, self._hint, + self._sort, ) @@ -456,6 +469,7 @@ class _UpdateOp: "_array_filters", "_hint", "_namespace", + "_sort", ) def __init__( @@ -467,6 +481,7 @@ def __init__( array_filters: Optional[list[Mapping[str, Any]]], hint: Optional[_IndexKeyHint], namespace: Optional[str], + sort: Optional[Mapping[str, Any]], ): if filter is not None: validate_is_mapping("filter", filter) @@ -478,13 +493,13 @@ def __init__( self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) else: self._hint = hint - self._filter = filter self._doc = doc self._upsert = upsert self._collation = collation self._array_filters = array_filters self._namespace = namespace + self._sort = sort def __eq__(self, other: object) -> bool: if isinstance(other, type(self)): @@ -496,6 +511,7 @@ def __eq__(self, other: object) -> bool: other._array_filters, other._hint, other._namespace, + other._sort, ) == ( self._filter, self._doc, @@ -504,6 +520,7 @@ def __eq__(self, other: object) -> bool: self._array_filters, self._hint, self._namespace, + self._sort, ) return NotImplemented @@ -512,7 +529,7 @@ def __ne__(self, other: Any) -> bool: def __repr__(self) -> str: if self._namespace: - return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( + return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( self.__class__.__name__, self._filter, self._doc, @@ -521,8 +538,9 @@ def __repr__(self) -> str: self._array_filters, self._hint, self._namespace, + self._sort, ) - return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( + return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( self.__class__.__name__, self._filter, self._doc, @@ -530,6 +548,7 @@ def __repr__(self) -> str: self._collation, self._array_filters, self._hint, + self._sort, ) @@ -547,6 +566,7 @@ def __init__( array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, namespace: Optional[str] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Represents an update_one operation. 
@@ -567,8 +587,12 @@ def __init__( :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. - :param namespace: (optional) The namespace in which to update a document. + :param namespace: The namespace in which to update a document. + :param sort: Specify which document the operation updates if the query matches + multiple documents. The first document matched by the sort order will be updated. + .. versionchanged:: 4.10 + Added ``sort`` option. .. versionchanged:: 4.9 Added the `namespace` option to support `MongoClient.bulk_write`. .. versionchanged:: 3.11 @@ -580,7 +604,7 @@ def __init__( .. versionchanged:: 3.5 Added the `collation` option. """ - super().__init__(filter, update, upsert, collation, array_filters, hint, namespace) + super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, sort) def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" @@ -592,6 +616,7 @@ def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: collation=validate_collation_or_none(self._collation), array_filters=self._array_filters, hint=self._hint, + sort=self._sort, ) def _add_to_client_bulk(self, bulkobj: _AgnosticClientBulk) -> None: @@ -609,6 +634,7 @@ def _add_to_client_bulk(self, bulkobj: _AgnosticClientBulk) -> None: collation=validate_collation_or_none(self._collation), array_filters=self._array_filters, hint=self._hint, + sort=self._sort, ) @@ -659,7 +685,7 @@ def __init__( .. versionchanged:: 3.5 Added the `collation` option. """ - super().__init__(filter, update, upsert, collation, array_filters, hint, namespace) + super().__init__(filter, update, upsert, collation, array_filters, hint, namespace, None) def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index c658157ea1..7fb29a977f 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -109,6 +109,7 @@ def __init__( self.uses_array_filters = False self.uses_hint_update = False self.uses_hint_delete = False + self.uses_sort = False self.is_retryable = True self.retrying = False self.started_retryable_write = False @@ -144,6 +145,7 @@ def add_update( collation: Optional[Mapping[str, Any]] = None, array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) @@ -159,6 +161,9 @@ def add_update( if hint is not None: self.uses_hint_update = True cmd["hint"] = hint + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort if multi: # A bulk_write containing an update_many is not retryable. 
self.is_retryable = False @@ -171,6 +176,7 @@ def add_replace( upsert: bool = False, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) @@ -181,6 +187,9 @@ def add_replace( if hint is not None: self.uses_hint_update = True cmd["hint"] = hint + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort self.ops.append((_UPDATE, cmd)) def add_delete( @@ -697,6 +706,10 @@ def execute_no_results( raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) + if unack and self.uses_sort and conn.max_wire_version < 25: + raise ConfigurationError( + "Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands." + ) # Cannot have both unacknowledged writes and bypass document validation. if self.bypass_doc_val: raise OperationFailure( diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index f41f0203f2..2c38b1d76c 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -118,6 +118,7 @@ def __init__( self.uses_array_filters = False self.uses_hint_update = False self.uses_hint_delete = False + self.uses_sort = False self.is_retryable = self.client.options.retry_writes self.retrying = False @@ -148,6 +149,7 @@ def add_update( collation: Optional[Mapping[str, Any]] = None, array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) @@ -169,6 +171,9 @@ def add_update( if collation is not None: self.uses_collation = True cmd["collation"] = collation + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort if multi: # A bulk_write containing an update_many is not retryable. self.is_retryable = False @@ -184,6 +189,7 @@ def add_replace( upsert: Optional[bool] = None, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, + sort: Optional[Mapping[str, Any]] = None, ) -> None: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) @@ -202,6 +208,9 @@ def add_replace( if collation is not None: self.uses_collation = True cmd["collation"] = collation + if sort is not None: + self.uses_sort = True + cmd["sort"] = sort self.ops.append(("replace", cmd)) self.namespaces.append(namespace) self.total_ops += 1 diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 6fd2ac82dd..6edfddc9a9 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -992,6 +992,7 @@ def _update( session: Optional[ClientSession] = None, retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: """Internal update / replace helper.""" @@ -1023,6 +1024,14 @@ def _update( if not isinstance(hint, str): hint = helpers_shared._index_document(hint) update_doc["hint"] = hint + if sort is not None: + if not acknowledged and conn.max_wire_version < 25: + raise ConfigurationError( + "Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands." 
+ ) + common.validate_is_mapping("sort", sort) + update_doc["sort"] = sort + command = {"update": self.name, "ordered": ordered, "updates": [update_doc]} if let is not None: common.validate_is_mapping("let", let) @@ -1078,6 +1087,7 @@ def _update_retryable( hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: """Internal update / replace helper.""" @@ -1101,6 +1111,7 @@ def _update( session=session, retryable_write=retryable_write, let=let, + sort=sort, comment=comment, ) @@ -1121,6 +1132,7 @@ def replace_one( hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: """Replace a single document matching the filter. @@ -1175,8 +1187,13 @@ def replace_one( aggregate expression context (e.g. "$$var"). :param comment: A user-provided comment to attach to this command. + :param sort: Specify which document the operation updates if the query matches + multiple documents. The first document matched by the sort order will be updated. + This option is only supported on MongoDB 8.0 and above. :return: - An instance of :class:`~pymongo.results.UpdateResult`. + .. versionchanged:: 4.11 + Added ``sort`` parameter. .. versionchanged:: 4.1 Added ``let`` parameter. Added ``comment`` parameter. @@ -1208,6 +1225,7 @@ def replace_one( hint=hint, session=session, let=let, + sort=sort, comment=comment, ), write_concern.acknowledged, @@ -1224,6 +1242,7 @@ def update_one( hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, let: Optional[Mapping[str, Any]] = None, + sort: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: """Update a single document matching the filter. @@ -1282,11 +1301,16 @@ def update_one( constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). + :param sort: Specify which document the operation updates if the query matches + multiple documents. The first document matched by the sort order will be updated. + This option is only supported on MongoDB 8.0 and above. :param comment: A user-provided comment to attach to this command. :return: - An instance of :class:`~pymongo.results.UpdateResult`. + .. versionchanged:: 4.11 + Added ``sort`` parameter. .. versionchanged:: 4.1 Added ``let`` parameter. Added ``comment`` parameter. 
@@ -1321,6 +1345,7 @@ def update_one( hint=hint, session=session, let=let, + sort=sort, comment=comment, ), write_concern.acknowledged, diff --git a/test/crud/unified/aggregate-write-readPreference.json b/test/crud/unified/aggregate-write-readPreference.json index bc887e83cb..c1fa3b4574 100644 --- a/test/crud/unified/aggregate-write-readPreference.json +++ b/test/crud/unified/aggregate-write-readPreference.json @@ -78,11 +78,6 @@ "x": 33 } ] - }, - { - "collectionName": "coll1", - "databaseName": "db0", - "documents": [] } ], "tests": [ @@ -159,22 +154,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll1", - "databaseName": "db0", - "documents": [ - { - "_id": 2, - "x": 22 - }, - { - "_id": 3, - "x": 33 - } - ] - } ] }, { @@ -250,22 +229,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll1", - "databaseName": "db0", - "documents": [ - { - "_id": 2, - "x": 22 - }, - { - "_id": 3, - "x": 33 - } - ] - } ] }, { @@ -344,22 +307,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll1", - "databaseName": "db0", - "documents": [ - { - "_id": 2, - "x": 22 - }, - { - "_id": 3, - "x": 33 - } - ] - } ] }, { @@ -438,22 +385,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll1", - "databaseName": "db0", - "documents": [ - { - "_id": 2, - "x": 22 - }, - { - "_id": 3, - "x": 33 - } - ] - } ] } ] diff --git a/test/crud/unified/bulkWrite-replaceOne-sort.json b/test/crud/unified/bulkWrite-replaceOne-sort.json new file mode 100644 index 0000000000..c0bd383514 --- /dev/null +++ b/test/crud/unified/bulkWrite-replaceOne-sort.json @@ -0,0 +1,239 @@ +{ + "description": "BulkWrite replaceOne-sort", + "schemaVersion": "1.0", + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], + "tests": [ + { + "description": "BulkWrite replaceOne with sort option", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "operations": [ + { + "object": "collection0", + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "replaceOne": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "replacement": { + "x": 1 + } + } + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "x": 1 + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + }, + { + "commandSucceededEvent": { + "reply": { + "ok": 1, + "n": 1 + }, + "commandName": "update" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 1 + } + ] + } + ] + }, + { + "description": "BulkWrite replaceOne with sort option unsupported (server-side error)", + "runOnRequirements": [ + { + "maxServerVersion": "7.99" + } + ], + "operations": [ + { + "object": "collection0", + "name": "bulkWrite", + "arguments": { + 
"requests": [ + { + "replaceOne": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "replacement": { + "x": 1 + } + } + } + ] + }, + "expectError": { + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "x": 1 + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/bulkWrite-updateOne-sort.json b/test/crud/unified/bulkWrite-updateOne-sort.json new file mode 100644 index 0000000000..f78bd3bf3e --- /dev/null +++ b/test/crud/unified/bulkWrite-updateOne-sort.json @@ -0,0 +1,255 @@ +{ + "description": "BulkWrite updateOne-sort", + "schemaVersion": "1.0", + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], + "tests": [ + { + "description": "BulkWrite updateOne with sort option", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "operations": [ + { + "object": "collection0", + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "updateOne": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "update": [ + { + "$set": { + "x": 1 + } + } + ] + } + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": [ + { + "$set": { + "x": 1 + } + } + ], + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + }, + { + "commandSucceededEvent": { + "reply": { + "ok": 1, + "n": 1 + }, + "commandName": "update" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 1 + } + ] + } + ] + }, + { + "description": "BulkWrite updateOne with sort option unsupported (server-side error)", + "runOnRequirements": [ + { + "maxServerVersion": "7.99" + } + ], + "operations": [ + { + "object": "collection0", + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "updateOne": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "update": [ + { + "$set": { + "x": 1 + } + } + ] + } + } + ] + }, + "expectError": { + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": [ + { + "$set": { + "x": 1 + } + } + ], + "sort": { + "_id": -1 + }, + "multi": 
{ + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/client-bulkWrite-partialResults.json b/test/crud/unified/client-bulkWrite-partialResults.json new file mode 100644 index 0000000000..b35e94a2ea --- /dev/null +++ b/test/crud/unified/client-bulkWrite-partialResults.json @@ -0,0 +1,540 @@ +{ + "description": "client bulkWrite partial results", + "schemaVersion": "1.4", + "runOnRequirements": [ + { + "minServerVersion": "8.0", + "serverless": "forbid" + } + ], + "createEntities": [ + { + "client": { + "id": "client0" + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + } + ] + } + ], + "_yamlAnchors": { + "namespace": "crud-tests.coll0", + "newDocument": { + "_id": 2, + "x": 22 + } + }, + "tests": [ + { + "description": "partialResult is unset when first operation fails during an ordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": true, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "$$unsetOrMatches": { + "insertedCount": { + "$$exists": false + }, + "upsertedCount": { + "$$exists": false + }, + "matchedCount": { + "$$exists": false + }, + "modifiedCount": { + "$$exists": false + }, + "deletedCount": { + "$$exists": false + }, + "insertResults": { + "$$exists": false + }, + "updateResults": { + "$$exists": false + }, + "deleteResults": { + "$$exists": false + } + } + } + } + } + ] + }, + { + "description": "partialResult is unset when first operation fails during an ordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": true, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "$$unsetOrMatches": { + "insertedCount": { + "$$exists": false + }, + "upsertedCount": { + "$$exists": false + }, + "matchedCount": { + "$$exists": false + }, + "modifiedCount": { + "$$exists": false + }, + "deletedCount": { + "$$exists": false + }, + "insertResults": { + "$$exists": false + }, + "updateResults": { + "$$exists": false + }, + "deleteResults": { + "$$exists": false + } + } + } + } + } + ] + }, + { + "description": "partialResult is set when second operation fails during an ordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": 
"crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": true, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "0": { + "insertedId": 2 + } + }, + "updateResults": {}, + "deleteResults": {} + } + } + } + ] + }, + { + "description": "partialResult is set when second operation fails during an ordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": true, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "$$unsetOrMatches": {} + }, + "updateResults": { + "$$unsetOrMatches": {} + }, + "deleteResults": { + "$$unsetOrMatches": {} + } + } + } + } + ] + }, + { + "description": "partialResult is unset when all operations fail during an unordered bulk write", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": false + }, + "expectError": { + "expectResult": { + "$$unsetOrMatches": { + "insertedCount": { + "$$exists": false + }, + "upsertedCount": { + "$$exists": false + }, + "matchedCount": { + "$$exists": false + }, + "modifiedCount": { + "$$exists": false + }, + "deletedCount": { + "$$exists": false + }, + "insertResults": { + "$$exists": false + }, + "updateResults": { + "$$exists": false + }, + "deleteResults": { + "$$exists": false + } + } + } + } + } + ] + }, + { + "description": "partialResult is set when first operation fails during an unordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": false, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "1": { + "insertedId": 2 + } + }, + "updateResults": {}, + "deleteResults": {} + } + } + } + ] + }, + { + "description": "partialResult is set when first operation fails during an unordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "ordered": false, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "$$unsetOrMatches": {} + }, + 
"updateResults": { + "$$unsetOrMatches": {} + }, + "deleteResults": { + "$$unsetOrMatches": {} + } + } + } + } + ] + }, + { + "description": "partialResult is set when second operation fails during an unordered bulk write (verbose)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": false, + "verboseResults": true + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "0": { + "insertedId": 2 + } + }, + "updateResults": {}, + "deleteResults": {} + } + } + } + ] + }, + { + "description": "partialResult is set when first operation fails during an unordered bulk write (summary)", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "insertOne": { + "namespace": "crud-tests.coll0", + "document": { + "_id": 1, + "x": 11 + } + } + } + ], + "ordered": false, + "verboseResults": false + }, + "expectError": { + "expectResult": { + "insertedCount": 1, + "upsertedCount": 0, + "matchedCount": 0, + "modifiedCount": 0, + "deletedCount": 0, + "insertResults": { + "$$unsetOrMatches": {} + }, + "updateResults": { + "$$unsetOrMatches": {} + }, + "deleteResults": { + "$$unsetOrMatches": {} + } + } + } + } + ] + } + ] +} diff --git a/test/crud/unified/client-bulkWrite-replaceOne-sort.json b/test/crud/unified/client-bulkWrite-replaceOne-sort.json new file mode 100644 index 0000000000..53218c1f48 --- /dev/null +++ b/test/crud/unified/client-bulkWrite-replaceOne-sort.json @@ -0,0 +1,162 @@ +{ + "description": "client bulkWrite updateOne-sort", + "schemaVersion": "1.4", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], + "_yamlAnchors": { + "namespace": "crud-tests.coll0" + }, + "tests": [ + { + "description": "client bulkWrite replaceOne with sort option", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "replaceOne": { + "namespace": "crud-tests.coll0", + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "replacement": { + "x": 1 + } + } + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "bulkWrite", + "databaseName": "admin", + "command": { + "bulkWrite": 1, + "ops": [ + { + "update": 0, + "filter": { + "_id": { + "$gt": 1 + } + }, + "updateMods": { + "x": 1 + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ], + "nsInfo": [ + { + "ns": 
"crud-tests.coll0" + } + ] + } + } + }, + { + "commandSucceededEvent": { + "reply": { + "ok": 1, + "nErrors": 0, + "nMatched": 1, + "nModified": 1 + }, + "commandName": "bulkWrite" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 1 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/client-bulkWrite-updateOne-sort.json b/test/crud/unified/client-bulkWrite-updateOne-sort.json new file mode 100644 index 0000000000..4a07b8b97c --- /dev/null +++ b/test/crud/unified/client-bulkWrite-updateOne-sort.json @@ -0,0 +1,166 @@ +{ + "description": "client bulkWrite updateOne-sort", + "schemaVersion": "1.4", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], + "_yamlAnchors": { + "namespace": "crud-tests.coll0" + }, + "tests": [ + { + "description": "client bulkWrite updateOne with sort option", + "operations": [ + { + "object": "client0", + "name": "clientBulkWrite", + "arguments": { + "models": [ + { + "updateOne": { + "namespace": "crud-tests.coll0", + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "bulkWrite", + "databaseName": "admin", + "command": { + "bulkWrite": 1, + "ops": [ + { + "update": 0, + "filter": { + "_id": { + "$gt": 1 + } + }, + "updateMods": { + "$inc": { + "x": 1 + } + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ], + "nsInfo": [ + { + "ns": "crud-tests.coll0" + } + ] + } + } + }, + { + "commandSucceededEvent": { + "reply": { + "ok": 1, + "nErrors": 0, + "nMatched": 1, + "nModified": 1 + }, + "commandName": "bulkWrite" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 34 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/db-aggregate-write-readPreference.json b/test/crud/unified/db-aggregate-write-readPreference.json index 2a81282de8..b6460f001f 100644 --- a/test/crud/unified/db-aggregate-write-readPreference.json +++ b/test/crud/unified/db-aggregate-write-readPreference.json @@ -52,13 +52,6 @@ } } ], - "initialData": [ - { - "collectionName": "coll0", - "databaseName": "db0", - "documents": [] - } - ], "tests": [ { "description": "Database-level aggregate with $out includes read preference for 5.0+ server", @@ -141,17 +134,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll0", - "databaseName": "db0", - "documents": [ - { - "_id": 1 - } - ] - } ] }, { @@ -235,17 +217,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll0", - "databaseName": "db0", - "documents": [ - { - "_id": 1 - } - ] - } ] }, { 
@@ -332,17 +303,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll0", - "databaseName": "db0", - "documents": [ - { - "_id": 1 - } - ] - } ] }, { @@ -429,17 +389,6 @@ } ] } - ], - "outcome": [ - { - "collectionName": "coll0", - "databaseName": "db0", - "documents": [ - { - "_id": 1 - } - ] - } ] } ] diff --git a/test/crud/unified/replaceOne-sort.json b/test/crud/unified/replaceOne-sort.json new file mode 100644 index 0000000000..cf2271dda5 --- /dev/null +++ b/test/crud/unified/replaceOne-sort.json @@ -0,0 +1,232 @@ +{ + "description": "replaceOne-sort", + "schemaVersion": "1.0", + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], + "tests": [ + { + "description": "ReplaceOne with sort option", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "operations": [ + { + "name": "replaceOne", + "object": "collection0", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "replacement": { + "x": 1 + } + }, + "expectResult": { + "matchedCount": 1, + "modifiedCount": 1, + "upsertedCount": 0 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "x": 1 + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + }, + { + "commandSucceededEvent": { + "reply": { + "ok": 1, + "n": 1 + }, + "commandName": "update" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 1 + } + ] + } + ] + }, + { + "description": "replaceOne with sort option unsupported (server-side error)", + "runOnRequirements": [ + { + "maxServerVersion": "7.99" + } + ], + "operations": [ + { + "name": "replaceOne", + "object": "collection0", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "replacement": { + "x": 1 + } + }, + "expectError": { + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "x": 1 + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/updateOne-sort.json b/test/crud/unified/updateOne-sort.json new file mode 100644 index 0000000000..8fe4f50b94 --- /dev/null +++ b/test/crud/unified/updateOne-sort.json @@ -0,0 +1,240 @@ +{ + "description": "updateOne-sort", + "schemaVersion": "1.0", + "createEntities": [ 
+ { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], + "tests": [ + { + "description": "UpdateOne with sort option", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "operations": [ + { + "name": "updateOne", + "object": "collection0", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "update": { + "$inc": { + "x": 1 + } + } + }, + "expectResult": { + "matchedCount": 1, + "modifiedCount": 1, + "upsertedCount": 0 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "$inc": { + "x": 1 + } + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + }, + { + "commandSucceededEvent": { + "reply": { + "ok": 1, + "n": 1 + }, + "commandName": "update" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 34 + } + ] + } + ] + }, + { + "description": "updateOne with sort option unsupported (server-side error)", + "runOnRequirements": [ + { + "maxServerVersion": "7.99" + } + ], + "operations": [ + { + "name": "updateOne", + "object": "collection0", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "sort": { + "_id": -1 + }, + "update": { + "$inc": { + "x": 1 + } + } + }, + "expectError": { + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll0", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "$inc": { + "x": 1 + } + }, + "sort": { + "_id": -1 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + } + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "coll0", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ] + } + ] +} diff --git a/test/utils.py b/test/utils.py index 4575a9fe10..3eac4fa509 100644 --- a/test/utils.py +++ b/test/utils.py @@ -958,10 +958,6 @@ def parse_spec_options(opts): def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback): for arg_name in list(arguments): c2s = camel_to_snake(arg_name) - # PyMongo accepts sort as list of tuples. - if arg_name == "sort": - sort_dict = arguments[arg_name] - arguments[arg_name] = list(sort_dict.items()) # Named "key" instead not fieldName. 
if arg_name == "fieldName": arguments["key"] = arguments.pop(arg_name) From 8ce21bc1217f4c3c722a50d81c08ec33ca9d46dc Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 17 Oct 2024 09:18:01 -0500 Subject: [PATCH 25/35] PYTHON-4872 Use shrub.py to generate encryption tasks (#1938) --- .evergreen/config.yml | 387 ++++++++++++++++++++------ .evergreen/scripts/generate_config.py | 88 +++++- 2 files changed, 371 insertions(+), 104 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index c3427e66d0..54a1ff3368 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2322,32 +2322,6 @@ axes: variables: COVERAGE: "coverage" - # Run encryption tests? - - id: encryption - display_name: "Encryption" - values: - - id: "encryption" - display_name: "Encryption" - tags: ["encryption_tag"] - variables: - test_encryption: true - batchtime: 10080 # 7 days - - id: "encryption_pyopenssl" - display_name: "Encryption PyOpenSSL" - tags: ["encryption_tag"] - variables: - test_encryption: true - test_encryption_pyopenssl: true - batchtime: 10080 # 7 days - # The path to crypt_shared is stored in the $CRYPT_SHARED_LIB_PATH expansion. - - id: "encryption_crypt_shared" - display_name: "Encryption shared lib" - tags: ["encryption_tag"] - variables: - test_encryption: true - test_crypt_shared: true - batchtime: 10080 # 7 days - # Run pyopenssl tests? - id: pyopenssl display_name: "PyOpenSSL" @@ -2864,6 +2838,303 @@ buildvariants: PYTHON_BINARY: C:/python/32/Python39/python.exe SKIP_CSOT_TESTS: "true" +# Encryption tests. +- name: encryption-rhel8-py3.9-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption RHEL8 py3.9 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [encryption_tag] +- name: encryption-rhel8-py3.13-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption RHEL8 py3.13 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [encryption_tag] +- name: encryption-rhel8-pypy3.10-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption RHEL8 pypy3.10 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: [encryption_tag] +- name: encryption-crypt_shared-rhel8-py3.9-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption crypt_shared RHEL8 py3.9 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [encryption_tag] +- name: encryption-crypt_shared-rhel8-py3.13-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption crypt_shared RHEL8 py3.13 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [encryption_tag] +- name: encryption-crypt_shared-rhel8-pypy3.10-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: 
.sharded_cluster + display_name: Encryption crypt_shared RHEL8 pypy3.10 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: [encryption_tag] +- name: encryption-pyopenssl-rhel8-py3.9-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption PyOpenSSL RHEL8 py3.9 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_encryption_pyopenssl: "true" + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [encryption_tag] +- name: encryption-pyopenssl-rhel8-py3.13-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption PyOpenSSL RHEL8 py3.13 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_encryption_pyopenssl: "true" + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [encryption_tag] +- name: encryption-pyopenssl-rhel8-pypy3.10-auth-ssl + tasks: + - name: .standalone + - name: .replica_set + - name: .sharded_cluster + display_name: Encryption PyOpenSSL RHEL8 pypy3.10 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_encryption_pyopenssl: "true" + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: [encryption_tag] +- name: encryption-rhel8-py3.10-auth-ssl + tasks: + - name: .replica_set + display_name: Encryption RHEL8 py3.10 Auth SSL + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: encryption-crypt_shared-rhel8-py3.11-auth-nossl + tasks: + - name: .replica_set + display_name: Encryption crypt_shared RHEL8 py3.11 Auth NoSSL + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: nossl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: encryption-pyopenssl-rhel8-py3.12-auth-ssl + tasks: + - name: .replica_set + display_name: Encryption PyOpenSSL RHEL8 py3.12 Auth SSL + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + TEST_ENCRYPTION_PYOPENSSL: "true" + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: encryption-rhel8-pypy3.9-auth-nossl + tasks: + - name: .replica_set + display_name: Encryption RHEL8 pypy3.9 Auth NoSSL + run_on: + - rhel87-small + expansions: + AUTH: auth + SSL: nossl + test_encryption: "true" + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: encryption-macos-py3.9-auth-ssl + tasks: + - name: .latest .replica_set + display_name: Encryption macOS py3.9 Auth SSL + run_on: + - macos-14 + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + tags: [encryption_tag] +- name: encryption-macos-py3.13-auth-nossl + tasks: + - name: .latest .replica_set + display_name: Encryption macOS py3.13 Auth NoSSL + run_on: + - macos-14 + batchtime: 10080 + expansions: + AUTH: auth + SSL: nossl + test_encryption: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 + tags: [encryption_tag] +- name: encryption-crypt_shared-macos-py3.9-auth-ssl + tasks: + - name: .latest .replica_set + display_name: Encryption crypt_shared macOS py3.9 Auth SSL + run_on: + - 
macos-14 + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + tags: [encryption_tag] +- name: encryption-crypt_shared-macos-py3.13-auth-nossl + tasks: + - name: .latest .replica_set + display_name: Encryption crypt_shared macOS py3.13 Auth NoSSL + run_on: + - macos-14 + batchtime: 10080 + expansions: + AUTH: auth + SSL: nossl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 + tags: [encryption_tag] +- name: encryption-win64-py3.9-auth-ssl + tasks: + - name: .latest .replica_set + display_name: Encryption Win64 py3.9 Auth SSL + run_on: + - windows-64-vsMulti-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + PYTHON_BINARY: C:/python/Python39/python.exe + tags: [encryption_tag] +- name: encryption-win64-py3.13-auth-nossl + tasks: + - name: .latest .replica_set + display_name: Encryption Win64 py3.13 Auth NoSSL + run_on: + - windows-64-vsMulti-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: nossl + test_encryption: "true" + PYTHON_BINARY: C:/python/Python313/python.exe + tags: [encryption_tag] +- name: encryption-crypt_shared-win64-py3.9-auth-ssl + tasks: + - name: .latest .replica_set + display_name: Encryption crypt_shared Win64 py3.9 Auth SSL + run_on: + - windows-64-vsMulti-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: ssl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: C:/python/Python39/python.exe + tags: [encryption_tag] +- name: encryption-crypt_shared-win64-py3.13-auth-nossl + tasks: + - name: .latest .replica_set + display_name: Encryption crypt_shared Win64 py3.13 Auth NoSSL + run_on: + - windows-64-vsMulti-small + batchtime: 10080 + expansions: + AUTH: auth + SSL: nossl + test_encryption: "true" + test_crypt_shared: "true" + PYTHON_BINARY: C:/python/Python313/python.exe + tags: [encryption_tag] + - matrix_name: "tests-fips" matrix_spec: platform: @@ -2874,33 +3145,6 @@ buildvariants: tasks: - "test-fips-standalone" -- matrix_name: "test-macos-encryption" - matrix_spec: - platform: - - macos - auth: "auth" - ssl: "nossl" - encryption: "*" - display_name: "${encryption} ${platform} ${auth} ${ssl}" - tasks: "test-latest-replica_set" - rules: - - if: - encryption: ["encryption", "encryption_crypt_shared"] - platform: macos - auth: "auth" - ssl: "nossl" - then: - add_tasks: &encryption-server-versions - - ".rapid" - - ".latest" - - ".8.0" - - ".7.0" - - ".6.0" - - ".5.0" - - ".4.4" - - ".4.2" - - ".4.0" - # Test one server version with zSeries, POWER8, and ARM. 
- matrix_name: "test-different-cpu-architectures" matrix_spec: @@ -2954,26 +3198,6 @@ buildvariants: tasks: - '.replica_set' -- matrix_name: "tests-python-version-rhel8-test-encryption" - matrix_spec: - platform: rhel8 - python-version: "*" - auth-ssl: noauth-nossl -# TODO: dependency error for 'coverage-report' task: -# dependency tests-python-version-rhel62-test-encryption_.../test-2.6-standalone is not present in the project config -# coverage: "*" - encryption: "*" - display_name: "${encryption} ${python-version} ${platform} ${auth-ssl}" - tasks: "test-latest-replica_set" - rules: - - if: - encryption: ["encryption", "encryption_crypt_shared"] - platform: rhel8 - auth-ssl: noauth-nossl - python-version: "*" - then: - add_tasks: *encryption-server-versions - - matrix_name: "tests-python-version-rhel8-without-c-extensions" matrix_spec: platform: rhel8 @@ -3057,23 +3281,6 @@ buildvariants: tasks: - ".5.0" -- matrix_name: "tests-windows-encryption" - matrix_spec: - platform: windows - python-version-windows: "*" - auth-ssl: "*" - encryption: "*" - display_name: "${encryption} ${platform} ${python-version-windows} ${auth-ssl}" - tasks: "test-latest-replica_set" - rules: - - if: - encryption: ["encryption", "encryption_crypt_shared"] - platform: windows - python-version-windows: "*" - auth-ssl: "*" - then: - add_tasks: *encryption-server-versions - # Storage engine tests on RHEL 8.4 (x86_64) with Python 3.9. - matrix_name: "tests-storage-engines" matrix_spec: diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index 044303ad8f..dcd97b093e 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -45,18 +45,13 @@ class Host: name: str run_on: str display_name: str - expansions: dict[str, str] -_macos_expansions = dict( # CSOT tests are unreliable on slow hosts. 
- SKIP_CSOT_TESTS="true" -) - -HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8", dict()) -HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64", _macos_expansions) -HOSTS["win32"] = Host("win32", "windows-64-vsMulti-small", "Win32", _macos_expansions) -HOSTS["macos"] = Host("macos", "macos-14", "macOS", _macos_expansions) -HOSTS["macos-arm64"] = Host("macos-arm64", "macos-14-arm64", "macOS Arm64", _macos_expansions) +HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8") +HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64") +HOSTS["win32"] = Host("win32", "windows-64-vsMulti-small", "Win32") +HOSTS["macos"] = Host("macos", "macos-14", "macOS") +HOSTS["macos-arm64"] = Host("macos-arm64", "macos-14-arm64", "macOS Arm64") ############## @@ -84,7 +79,6 @@ def create_variant( expansions["PYTHON_BINARY"] = get_python_binary(python, host) if version: expansions["VERSION"] = version - expansions.update(HOSTS[host].expansions) expansions = expansions or None return BuildVariant( name=name, @@ -129,7 +123,7 @@ def get_display_name(base: str, host: str, **kwargs) -> str: elif key.lower() in DISPLAY_LOOKUP: name = DISPLAY_LOOKUP[key.lower()][value] else: - raise ValueError(f"Missing display handling for {key}") + continue display_name = f"{display_name} {name}" return display_name @@ -235,7 +229,7 @@ def create_server_variants() -> list[BuildVariant]: zip_cycle(MIN_MAX_PYTHON, AUTH_SSLS, TOPOLOGIES), SYNCS ): test_suite = "default" if sync == "sync" else "default_async" - expansions = dict(AUTH=auth, SSL=ssl, TEST_SUITES=test_suite) + expansions = dict(AUTH=auth, SSL=ssl, TEST_SUITES=test_suite, SKIP_CSOT_TESTS="true") display_name = get_display_name("Test", host, python=python, **expansions) variant = create_variant( [f".{topology}"], @@ -249,8 +243,74 @@ return variants +def create_encryption_variants() -> list[BuildVariant]: + variants = [] + tags = ["encryption_tag"] + batchtime = BATCHTIME_WEEK + + def get_encryption_expansions(encryption, ssl="ssl"): + expansions = dict(AUTH="auth", SSL=ssl, test_encryption="true") + if "crypt_shared" in encryption: + expansions["test_crypt_shared"] = "true" + if "PyOpenSSL" in encryption: + expansions["test_encryption_pyopenssl"] = "true" + return expansions + + host = "rhel8" + + # Test against all server versions and topologies for the three main python versions. + encryptions = ["Encryption", "Encryption crypt_shared", "Encryption PyOpenSSL"] + for encryption, python in product(encryptions, [*MIN_MAX_PYTHON, PYPYS[-1]]): + expansions = get_encryption_expansions(encryption) + display_name = get_display_name(encryption, host, python=python, **expansions) + variant = create_variant( + [f".{t}" for t in TOPOLOGIES], + display_name, + python=python, + host=host, + expansions=expansions, + batchtime=batchtime, + tags=tags, + ) + variants.append(variant) + + # Test the rest of the pythons on linux for all server versions. + for encryption, python, ssl in zip_cycle( + encryptions, CPYTHONS[1:-1] + PYPYS[:-1], ["ssl", "nossl"] + ): + expansions = get_encryption_expansions(encryption, ssl) + display_name = get_display_name(encryption, host, python=python, **expansions) + variant = create_variant( + [".replica_set"], + display_name, + python=python, + host=host, + expansions=expansions, + ) + variants.append(variant) + + # Test on macos and win64 on one server version and topology for min and max python.
+ encryptions = ["Encryption", "Encryption crypt_shared"] + task_names = [".latest .replica_set"] + for host, encryption, python in product(["macos", "win64"], encryptions, MIN_MAX_PYTHON): + ssl = "ssl" if python == CPYTHONS[0] else "nossl" + expansions = get_encryption_expansions(encryption, ssl) + display_name = get_display_name(encryption, host, python=python, **expansions) + variant = create_variant( + task_names, + display_name, + python=python, + host=host, + expansions=expansions, + batchtime=batchtime, + tags=tags, + ) + variants.append(variant) + return variants + + ################## # Generate Config ################## -generate_yaml(variants=create_server_variants()) +generate_yaml(variants=create_encryption_variants()) From a62ade864ddc07f1c0ee2782ef07ecfbf07fefd7 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 17 Oct 2024 11:32:39 -0400 Subject: [PATCH 26/35] PYTHON-4874 - Add KMS support for async Windows (#1939) --- pymongo/network_layer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index d14a21f41d..7a325853c8 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -205,7 +205,7 @@ async def _async_sendall_ssl( total_sent += sent async def _async_receive_ssl( - conn: _sslConn, length: int, dummy: AbstractEventLoop + conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) total_read = 0 @@ -215,6 +215,9 @@ async def _async_receive_ssl( while total_read < length: try: read = conn.recv_into(mv[total_read:]) + # KMS responses update their expected size after the first batch, stop reading after one loop + if once: + return mv[:read] if read == 0: raise OSError("connection closed") except BLOCKING_IO_ERRORS: From 79033bc0b9a6e404dc7680a03d856f12942ec720 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 17 Oct 2024 10:33:44 -0500 Subject: [PATCH 27/35] Revert "PYTHON-4765 Resync server-selection spec" (#1940) --- .../operation-id.json | 4 +- .../server_selection_logging/replica-set.json | 2 +- test/server_selection_logging/sharded.json | 2 +- test/server_selection_logging/standalone.json | 930 +++++++++++++++++- 4 files changed, 932 insertions(+), 6 deletions(-) diff --git a/test/server_selection_logging/operation-id.json b/test/server_selection_logging/operation-id.json index 72ebff60d8..ccc2623166 100644 --- a/test/server_selection_logging/operation-id.json +++ b/test/server_selection_logging/operation-id.json @@ -197,7 +197,7 @@ } }, { - "level": "info", + "level": "debug", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", @@ -383,7 +383,7 @@ } }, { - "level": "info", + "level": "debug", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", diff --git a/test/server_selection_logging/replica-set.json b/test/server_selection_logging/replica-set.json index 5eba784bf2..830b1ea51a 100644 --- a/test/server_selection_logging/replica-set.json +++ b/test/server_selection_logging/replica-set.json @@ -184,7 +184,7 @@ } }, { - "level": "info", + "level": "debug", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", diff --git a/test/server_selection_logging/sharded.json b/test/server_selection_logging/sharded.json index d42fba9100..346c050f9e 100644 --- a/test/server_selection_logging/sharded.json +++ b/test/server_selection_logging/sharded.json @@ 
-193,7 +193,7 @@ } }, { - "level": "info", + "level": "debug", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", diff --git a/test/server_selection_logging/standalone.json b/test/server_selection_logging/standalone.json index 3b3eddd841..3152d0bbf3 100644 --- a/test/server_selection_logging/standalone.json +++ b/test/server_selection_logging/standalone.json @@ -47,9 +47,29 @@ } } ], + "initialData": [ + { + "collectionName": "server-selection", + "databaseName": "logging-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], "tests": [ { - "description": "A successful operation", + "description": "A successful insert operation", "operations": [ { "name": "waitForEvent", @@ -191,7 +211,7 @@ } }, { - "level": "info", + "level": "debug", "component": "serverSelection", "data": { "message": "Waiting for suitable server to become available", @@ -230,6 +250,912 @@ ] } ] + }, + { + "description": "A successful find operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "findOne", + "object": "collection", + "arguments": { + "filter": { + "x": 1 + } + } + } + + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "find", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "find", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "A successful findAndModify operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "findOneAndReplace", + "object": "collection", + "arguments": { + "filter": { + "x": 1 + }, + "replacement": { + "x": 11 + } + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "findAndModify", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "findAndModify", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "A successful find and getMore operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "find", + "object": "collection", + "arguments": { + "batchSize": 3 + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + 
"component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "find", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "find", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "getMore", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "getMore", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "A successful aggregate operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "aggregate", + "object": "collection", + "arguments": { + "pipeline": [ + { + "$match": { + "_id": { + "$gt": 1 + } + } + } + ] + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "aggregate", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "aggregate", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "A successful count operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "countDocuments", + "object": "collection", + "arguments": { + "filter": {} + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "count", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "count", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "A successful distinct operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "distinct", + "object": "collection", + "arguments": { + "fieldName": "x", + 
"filter": {} + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "distinct", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "distinct", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "Successful collection management operations", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "createCollection", + "object": "database", + "arguments": { + "collection": "foo" + } + }, + { + "name": "listCollections", + "object": "database" + }, + { + "name": "dropCollection", + "object": "database", + "arguments": { + "collection": "foo" + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "create", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "create", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "listCollections", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "listCollections", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "drop", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "drop", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "Successful index operations", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "createIndex", + "object": "collection", + "arguments": { + "keys": { + "x": 1 + }, + "name": "x_1" + } + }, + { + "name": "listIndexes", + "object": "collection" + }, + { + "name": "dropIndex", + "object": "collection", + "arguments": { + "name": 
"x_1" + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "createIndexes", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "createIndexes", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "listIndexes", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "listIndexes", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "dropIndexes", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "dropIndexes", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "A successful update operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "updateOne", + "object": "collection", + "arguments": { + "filter": { + "x": 1 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "update", + "topologyDescription": { + "$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "update", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] + }, + { + "description": "A successful delete operation", + "operations": [ + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "topologyDescriptionChangedEvent": {} + }, + "count": 2 + } + }, + { + "name": "deleteOne", + "object": "collection", + "arguments": { + "filter": { + "x": 1 + } + } + } + ], + "expectLogMessages": [ + { + "client": "client", + "messages": [ + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection started", + "selector": { + "$$exists": true + }, + "operation": "delete", + "topologyDescription": { + 
"$$exists": true + } + } + }, + { + "level": "debug", + "component": "serverSelection", + "data": { + "message": "Server selection succeeded", + "selector": { + "$$exists": true + }, + "operation": "delete", + "topologyDescription": { + "$$exists": true + }, + "serverHost": { + "$$type": "string" + }, + "serverPort": { + "$$type": [ + "int", + "long" + ] + } + } + } + ] + } + ] } ] } From 257aa2483be980005d4a54f97025baebcb3102f3 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 17 Oct 2024 11:01:47 -0500 Subject: [PATCH 28/35] PYTHON-4878 Use shrub.py for load balancer tests (#1941) --- .evergreen/config.yml | 218 +++++++++++++++++++++++--- .evergreen/scripts/generate_config.py | 38 ++++- 2 files changed, 230 insertions(+), 26 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 54a1ff3368..ae7c0a6590 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2354,16 +2354,6 @@ axes: variables: ORCHESTRATION_FILE: "versioned-api-testing.json" - # Run load balancer tests? - - id: loadbalancer - display_name: "Load Balancer" - values: - - id: "enabled" - display_name: "Load Balancer" - variables: - test_loadbalancer: true - batchtime: 10080 # 7 days - - id: serverless display_name: "Serverless" values: @@ -3580,6 +3570,203 @@ buildvariants: VERSION: "8.0" PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 +# Load balancer tests +- name: load-balancer-rhel8-v6.0-py3.9-auth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v6.0 py3.9 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "6.0" + AUTH: auth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.9/bin/python3 +- name: load-balancer-rhel8-v6.0-py3.10-noauth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v6.0 py3.10 NoAuth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "6.0" + AUTH: noauth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: load-balancer-rhel8-v6.0-py3.11-noauth-nossl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v6.0 py3.11 NoAuth NoSSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "6.0" + AUTH: noauth + SSL: nossl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: load-balancer-rhel8-v7.0-py3.12-auth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v7.0 py3.12 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "7.0" + AUTH: auth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: load-balancer-rhel8-v7.0-py3.13-noauth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v7.0 py3.13 NoAuth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "7.0" + AUTH: noauth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.13/bin/python3 +- name: load-balancer-rhel8-v7.0-pypy3.9-noauth-nossl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v7.0 pypy3.9 NoAuth NoSSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "7.0" + AUTH: noauth + SSL: nossl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: load-balancer-rhel8-v8.0-pypy3.10-auth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v8.0 pypy3.10 Auth SSL + run_on: + - rhel87-small + 
batchtime: 10080 + expansions: + VERSION: "8.0" + AUTH: auth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 +- name: load-balancer-rhel8-v8.0-py3.9-noauth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v8.0 py3.9 NoAuth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "8.0" + AUTH: noauth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.9/bin/python3 +- name: load-balancer-rhel8-v8.0-py3.10-noauth-nossl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 v8.0 py3.10 NoAuth NoSSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: "8.0" + AUTH: noauth + SSL: nossl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: load-balancer-rhel8-latest-py3.11-auth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 latest py3.11 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: latest + AUTH: auth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: load-balancer-rhel8-latest-py3.12-noauth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 latest py3.12 NoAuth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: latest + AUTH: noauth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: load-balancer-rhel8-latest-py3.13-noauth-nossl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 latest py3.13 NoAuth NoSSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: latest + AUTH: noauth + SSL: nossl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.13/bin/python3 +- name: load-balancer-rhel8-rapid-pypy3.9-auth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 rapid pypy3.9 Auth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: rapid + AUTH: auth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: load-balancer-rhel8-rapid-pypy3.10-noauth-ssl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 rapid pypy3.10 NoAuth SSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: rapid + AUTH: noauth + SSL: ssl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 +- name: load-balancer-rhel8-rapid-py3.9-noauth-nossl + tasks: + - name: load-balancer-test + display_name: Load Balancer RHEL8 rapid py3.9 NoAuth NoSSL + run_on: + - rhel87-small + batchtime: 10080 + expansions: + VERSION: rapid + AUTH: noauth + SSL: nossl + test_loadbalancer: "true" + PYTHON_BINARY: /opt/python/3.9/bin/python3 + - matrix_name: "oidc-auth-test" matrix_spec: platform: [ rhel8, macos, windows ] @@ -3643,17 +3830,6 @@ buildvariants: - name: "aws-auth-test-rapid" - name: "aws-auth-test-latest" -- matrix_name: "load-balancer" - matrix_spec: - platform: rhel8 - mongodb-version: ["6.0", "7.0", "8.0", "rapid", "latest"] - auth-ssl: "*" - python-version: "*" - loadbalancer: "*" - display_name: "Load Balancer ${platform} ${python-version} ${mongodb-version} ${auth-ssl}" - tasks: - - name: "load-balancer-test" - - name: testgcpkms-variant display_name: "GCP KMS" run_on: diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index dcd97b093e..03b900301c 100644 --- a/.evergreen/scripts/generate_config.py +++ 
b/.evergreen/scripts/generate_config.py @@ -112,12 +112,14 @@ def get_python_binary(python: str, host: str) -> str: def get_display_name(base: str, host: str, **kwargs) -> str: """Get the display name of a variant.""" display_name = f"{base} {HOSTS[host].display_name}" + version = kwargs.pop("VERSION", None) + if version: + if version not in ["rapid", "latest"]: + version = f"v{version}" + display_name = f"{display_name} {version}" for key, value in kwargs.items(): name = value - if key == "version": - if value not in ["rapid", "latest"]: - name = f"v{value}" - elif key == "python": + if key.lower() == "python": if not value.startswith("pypy"): name = f"py{value}" elif key.lower() in DISPLAY_LOOKUP: @@ -309,8 +311,34 @@ def get_encryption_expansions(encryption, ssl="ssl"): return variants +def create_load_balancer_variants(): + # Load balancer tests - run all supported versions for all combinations of auth and ssl and system python. + host = "rhel8" + task_names = ["load-balancer-test"] + batchtime = BATCHTIME_WEEK + expansions_base = dict(test_loadbalancer="true") + versions = ["6.0", "7.0", "8.0", "latest", "rapid"] + variants = [] + pythons = CPYTHONS + PYPYS + for ind, (version, (auth, ssl)) in enumerate(product(versions, AUTH_SSLS)): + expansions = dict(VERSION=version, AUTH=auth, SSL=ssl) + expansions.update(expansions_base) + python = pythons[ind % len(pythons)] + display_name = get_display_name("Load Balancer", host, python=python, **expansions) + variant = create_variant( + task_names, + display_name, + python=python, + host=host, + expansions=expansions, + batchtime=batchtime, + ) + variants.append(variant) + return variants + + ################## # Generate Config ################## -generate_yaml(variants=create_encryption_variants()) +generate_yaml(variants=create_load_balancer_variants()) From 317a539415a91a4057c9104a2192d05fb47af2e1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 17 Oct 2024 15:01:24 -0500 Subject: [PATCH 29/35] PYTHON-4879 Use shrub.py for compressor tests (#1944) --- .evergreen/config.yml | 135 ++++++++++++++++---------- .evergreen/scripts/generate_config.py | 51 +++++++++- 2 files changed, 136 insertions(+), 50 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index ae7c0a6590..a7efc223b9 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2112,23 +2112,6 @@ axes: AUTH: "noauth" SSL: "nossl" - # Choice of wire protocol compression support - - id: compression - display_name: Compression - values: - - id: snappy - display_name: snappy compression - variables: - COMPRESSORS: "snappy" - - id: zlib - display_name: zlib compression - variables: - COMPRESSORS: "zlib" - - id: zstd - display_name: zstd compression - variables: - COMPRESSORS: "zstd" - # Choice of MongoDB server version - id: mongodb-version display_name: "MongoDB" @@ -3125,6 +3108,92 @@ buildvariants: PYTHON_BINARY: C:/python/Python313/python.exe tags: [encryption_tag] +# Compressor tests. 
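The fifteen load-balancer variants added by PATCH 28 above are generated, not hand-written: create_load_balancer_variants walks product(versions, AUTH_SSLS) and picks an interpreter with a modulo index, so each version/auth/ssl combination runs on exactly one Python instead of the full matrix the removed axis produced. The compressor variants that follow are built the same way in PATCH 29. A self-contained sketch of the rotation, with the constants copied from generate_config.py:

    from itertools import product

    CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"]
    PYPYS = ["pypy3.9", "pypy3.10"]
    AUTH_SSLS = [("auth", "ssl"), ("noauth", "ssl"), ("noauth", "nossl")]
    VERSIONS = ["6.0", "7.0", "8.0", "latest", "rapid"]

    pythons = CPYTHONS + PYPYS
    for ind, (version, (auth, ssl)) in enumerate(product(VERSIONS, AUTH_SSLS)):
        # Round-robin: 15 combinations cycle through the 7 interpreters.
        python = pythons[ind % len(pythons)]
        print(f"{version} {auth}/{ssl} -> {python}")

The first lines of output line up with the variant names above: 6.0/auth/ssl gets 3.9, 6.0/noauth/ssl gets 3.10, and so on through rapid/noauth/nossl back on 3.9.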
+- name: snappy-compression-rhel8-py3.9-no-c + tasks: + - name: .standalone + display_name: snappy compression RHEL8 py3.9 No C + run_on: + - rhel87-small + expansions: + COMPRESSORS: snappy + NO_EXT: "1" + PYTHON_BINARY: /opt/python/3.9/bin/python3 +- name: snappy-compression-rhel8-py3.10 + tasks: + - name: .standalone + display_name: snappy compression RHEL8 py3.10 + run_on: + - rhel87-small + expansions: + COMPRESSORS: snappy + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: zlib-compression-rhel8-py3.11-no-c + tasks: + - name: .standalone + display_name: zlib compression RHEL8 py3.11 No C + run_on: + - rhel87-small + expansions: + COMPRESSORS: zlib + NO_EXT: "1" + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: zlib-compression-rhel8-py3.12 + tasks: + - name: .standalone + display_name: zlib compression RHEL8 py3.12 + run_on: + - rhel87-small + expansions: + COMPRESSORS: zlib + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: zstd-compression-rhel8-py3.13-no-c + tasks: + - name: .standalone !.4.0 + display_name: zstd compression RHEL8 py3.13 No C + run_on: + - rhel87-small + expansions: + COMPRESSORS: zstd + NO_EXT: "1" + PYTHON_BINARY: /opt/python/3.13/bin/python3 +- name: zstd-compression-rhel8-py3.9 + tasks: + - name: .standalone !.4.0 + display_name: zstd compression RHEL8 py3.9 + run_on: + - rhel87-small + expansions: + COMPRESSORS: zstd + PYTHON_BINARY: /opt/python/3.9/bin/python3 +- name: snappy-compression-rhel8-pypy3.9 + tasks: + - name: .standalone + display_name: snappy compression RHEL8 pypy3.9 + run_on: + - rhel87-small + expansions: + COMPRESSORS: snappy + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: zlib-compression-rhel8-pypy3.10 + tasks: + - name: .standalone + display_name: zlib compression RHEL8 pypy3.10 + run_on: + - rhel87-small + expansions: + COMPRESSORS: zlib + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 +- name: zstd-compression-rhel8-pypy3.9 + tasks: + - name: .standalone !.4.0 + display_name: zstd compression RHEL8 pypy3.9 + run_on: + - rhel87-small + expansions: + COMPRESSORS: zstd + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 + - matrix_name: "tests-fips" matrix_spec: platform: @@ -3214,38 +3283,6 @@ buildvariants: - ".4.2" - ".4.0" -- matrix_name: "tests-python-version-rhel8-compression" - matrix_spec: - platform: rhel8 - python-version: "*" - c-extensions: "*" - compression: "*" - exclude_spec: - # These interpreters are always tested without extensions. - - platform: rhel8 - python-version: ["pypy3.9", "pypy3.10"] - c-extensions: "with-c-extensions" - compression: "*" - display_name: "${compression} ${c-extensions} ${python-version} ${platform}" - tasks: - - "test-latest-standalone" - - "test-8.0-standalone" - - "test-7.0-standalone" - - "test-6.0-standalone" - - "test-5.0-standalone" - - "test-4.4-standalone" - - "test-4.2-standalone" - - "test-4.0-standalone" - rules: - # Server version 4.0 supports snappy and zlib but not zstd. 
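The matrix rule being removed here encoded the zstd restriction procedurally; the generated variants above encode it in the task selector instead, where zstd variants request ".standalone !.4.0". Roughly, a dotted term requires a tag and a "!."-prefixed term excludes one. A hypothetical evaluator for that selector syntax, assuming Evergreen-style tag matching (an illustration, not Evergreen's implementation):

    def selector_matches(selector: str, tags: set[str]) -> bool:
        for term in selector.split():
            if term.startswith("!."):
                if term[2:] in tags:      # excluded tag present
                    return False
            elif term.startswith("."):
                if term[1:] not in tags:  # required tag missing
                    return False
        return True

    # zstd runs against 8.0 standalones but skips 4.0 (no server-side zstd).
    assert selector_matches(".standalone !.4.0", {"standalone", "8.0"})
    assert not selector_matches(".standalone !.4.0", {"standalone", "4.0"})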
- - if: - python-version: "*" - c-extensions: "*" - compression: ["zstd"] - then: - remove_tasks: - - "test-4.0-standalone" - - matrix_name: "tests-python-version-green-framework-rhel8" matrix_spec: platform: rhel8 diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index 03b900301c..91dedeb620 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -30,12 +30,14 @@ BATCHTIME_WEEK = 10080 AUTH_SSLS = [("auth", "ssl"), ("noauth", "ssl"), ("noauth", "nossl")] TOPOLOGIES = ["standalone", "replica_set", "sharded_cluster"] +C_EXTS = ["with_ext", "without_ext"] SYNCS = ["sync", "async"] DISPLAY_LOOKUP = dict( ssl=dict(ssl="SSL", nossl="NoSSL"), auth=dict(auth="Auth", noauth="NoAuth"), test_suites=dict(default="Sync", default_async="Async"), coverage=dict(coverage="cov"), + no_ext={"1": "No C"}, ) HOSTS = dict() @@ -137,6 +139,12 @@ def zip_cycle(*iterables, empty_default=None): yield tuple(next(i, empty_default) for i in cycles) +def handle_c_ext(c_ext, expansions): + """Handle c extension option.""" + if c_ext == C_EXTS[0]: + expansions["NO_EXT"] = "1" + + def generate_yaml(tasks=None, variants=None): """Generate the yaml for a given set of tasks and variants.""" project = EvgProject(tasks=tasks, buildvariants=variants) @@ -337,8 +345,49 @@ def create_load_balancer_variants(): return variants +def create_compression_variants(): + # Compression tests - standalone versions of each server, across python versions, with and without c extensions. + # PyPy interpreters are always tested without extensions. + host = "rhel8" + task_names = dict(snappy=[".standalone"], zlib=[".standalone"], zstd=[".standalone !.4.0"]) + variants = [] + for ind, (compressor, c_ext) in enumerate(product(["snappy", "zlib", "zstd"], C_EXTS)): + expansions = dict(COMPRESSORS=compressor) + handle_c_ext(c_ext, expansions) + base_name = f"{compressor} compression" + python = CPYTHONS[ind % len(CPYTHONS)] + display_name = get_display_name(base_name, host, python=python, **expansions) + variant = create_variant( + task_names[compressor], + display_name, + python=python, + host=host, + expansions=expansions, + ) + variants.append(variant) + + other_pythons = PYPYS + CPYTHONS[ind:] + for compressor, python in zip_cycle(["snappy", "zlib", "zstd"], other_pythons): + expansions = dict(COMPRESSORS=compressor) + handle_c_ext(c_ext, expansions) + base_name = f"{compressor} compression" + display_name = get_display_name(base_name, host, python=python, **expansions) + variant = create_variant( + task_names[compressor], + display_name, + python=python, + host=host, + expansions=expansions, + ) + variants.append(variant) + + return variants + + ################## # Generate Config ################## -generate_yaml(variants=create_load_balancer_variants()) +variants = create_compression_variants() +# print(len(variants)) +generate_yaml(variants=variants) From 7e904b3c31d3242ef9e202438e50cbecf611af75 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 17 Oct 2024 13:11:20 -0700 Subject: [PATCH 30/35] PYTHON-4874 Fix async Windows KMS support (#1942) --- pymongo/network_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 7a325853c8..aa16e85a07 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -215,11 +215,11 @@ async def _async_receive_ssl( while total_read < length: try: read = conn.recv_into(mv[total_read:]) + if read == 0: + raise OSError("connection 
closed") # KMS responses update their expected size after the first batch, stop reading after one loop if once: return mv[:read] - if read == 0: - raise OSError("connection closed") except BLOCKING_IO_ERRORS: await asyncio.sleep(backoff) read = 0 From 335b728f070a350239123a755856a1a3d1d51746 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 17 Oct 2024 20:27:27 -0500 Subject: [PATCH 31/35] Bump pyright from 1.1.383 to 1.1.384 (#1922) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jib --- requirements/typing.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/typing.txt b/requirements/typing.txt index 06c33c6db6..2c23212da7 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ mypy==1.11.2 -pyright==1.1.383 +pyright==1.1.384 typing_extensions -r ./encryption.txt -r ./ocsp.txt From 021a9f75243b466d2a333c991ad4d9eb52d8275e Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 18 Oct 2024 08:57:20 -0500 Subject: [PATCH 32/35] PYTHON-4882 Use shrub.py for enterprise auth tests (#1945) --- .evergreen/config.yml | 83 +++++++++++++++++++++------ .evergreen/run-tests.sh | 2 + .evergreen/scripts/generate_config.py | 23 +++++++- pyproject.toml | 1 + test/asynchronous/test_auth.py | 4 ++ test/test_auth.py | 4 ++ 6 files changed, 98 insertions(+), 19 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index a7efc223b9..9e2ab77088 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -3194,6 +3194,71 @@ buildvariants: COMPRESSORS: zstd PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +# Enterprise auth tests. +- name: enterprise-auth-macos-py3.9-auth + tasks: + - name: test-enterprise-auth + display_name: Enterprise Auth macOS py3.9 Auth + run_on: + - macos-14 + expansions: + AUTH: auth + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 +- name: enterprise-auth-rhel8-py3.10-auth + tasks: + - name: test-enterprise-auth + display_name: Enterprise Auth RHEL8 py3.10 Auth + run_on: + - rhel87-small + expansions: + AUTH: auth + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: enterprise-auth-rhel8-py3.11-auth + tasks: + - name: test-enterprise-auth + display_name: Enterprise Auth RHEL8 py3.11 Auth + run_on: + - rhel87-small + expansions: + AUTH: auth + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: enterprise-auth-rhel8-py3.12-auth + tasks: + - name: test-enterprise-auth + display_name: Enterprise Auth RHEL8 py3.12 Auth + run_on: + - rhel87-small + expansions: + AUTH: auth + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: enterprise-auth-win64-py3.13-auth + tasks: + - name: test-enterprise-auth + display_name: Enterprise Auth Win64 py3.13 Auth + run_on: + - windows-64-vsMulti-small + expansions: + AUTH: auth + PYTHON_BINARY: C:/python/Python313/python.exe +- name: enterprise-auth-rhel8-pypy3.9-auth + tasks: + - name: test-enterprise-auth + display_name: Enterprise Auth RHEL8 pypy3.9 Auth + run_on: + - rhel87-small + expansions: + AUTH: auth + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: enterprise-auth-rhel8-pypy3.10-auth + tasks: + - name: test-enterprise-auth + display_name: Enterprise Auth RHEL8 pypy3.10 Auth + run_on: + - rhel87-small + expansions: + AUTH: auth + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + - matrix_name: "tests-fips" matrix_spec: platform: @@ -3350,24 +3415,6 @@ buildvariants: tasks: - 
".latest" -- matrix_name: "test-linux-enterprise-auth" - matrix_spec: - platform: rhel8 - python-version: "*" - auth: "auth" - display_name: "Enterprise ${auth} ${platform} ${python-version}" - tasks: - - name: "test-enterprise-auth" - -- matrix_name: "tests-windows-enterprise-auth" - matrix_spec: - platform: windows - python-version-windows: "*" - auth: "auth" - display_name: "Enterprise ${auth} ${platform} ${python-version-windows}" - tasks: - - name: "test-enterprise-auth" - - matrix_name: "test-search-index-helpers" matrix_spec: platform: rhel8 diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 364570999f..36fa76e317 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -90,6 +90,8 @@ if [ -n "$TEST_ENTERPRISE_AUTH" ]; then export GSSAPI_HOST=${SASL_HOST} export GSSAPI_PORT=${SASL_PORT} export GSSAPI_PRINCIPAL=${PRINCIPAL} + + export TEST_SUITES="auth" fi if [ -n "$TEST_LOADBALANCER" ]; then diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index 91dedeb620..c3bfeef7af 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -384,10 +384,31 @@ def create_compression_variants(): return variants +def create_enterprise_auth_variants(): + expansions = dict(AUTH="auth") + variants = [] + + # All python versions across platforms. + for python in ALL_PYTHONS: + if python == CPYTHONS[0]: + host = "macos" + elif python == CPYTHONS[-1]: + host = "win64" + else: + host = "rhel8" + display_name = get_display_name("Enterprise Auth", host, python=python, **expansions) + variant = create_variant( + ["test-enterprise-auth"], display_name, host=host, python=python, expansions=expansions + ) + variants.append(variant) + + return variants + + ################## # Generate Config ################## -variants = create_compression_variants() +variants = create_enterprise_auth_variants() # print(len(variants)) generate_yaml(variants=variants) diff --git a/pyproject.toml b/pyproject.toml index b4f59f67d5..9a29a777fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,7 @@ filterwarnings = [ markers = [ "auth_aws: tests that rely on pymongo-auth-aws", "auth_oidc: tests that rely on oidc auth", + "auth: tests that rely on authentication", "ocsp: tests that rely on ocsp", "atlas: tests that rely on atlas", "data_lake: tests that rely on atlas data lake", diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index fbaca41f09..9262714374 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -32,6 +32,8 @@ ) from test.utils import AllowListEventListener, delay, ignore_deprecations +import pytest + from pymongo import AsyncMongoClient, monitoring from pymongo.asynchronous.auth import HAVE_KERBEROS from pymongo.auth_shared import _build_credentials_tuple @@ -42,6 +44,8 @@ _IS_SYNC = False +pytestmark = pytest.mark.auth + # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. 
GSSAPI_HOST = os.environ.get("GSSAPI_HOST") GSSAPI_PORT = int(os.environ.get("GSSAPI_PORT", "27017")) diff --git a/test/test_auth.py b/test/test_auth.py index b311d330bc..310006afff 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -32,6 +32,8 @@ ) from test.utils import AllowListEventListener, delay, ignore_deprecations +import pytest + from pymongo import MongoClient, monitoring from pymongo.auth_shared import _build_credentials_tuple from pymongo.errors import OperationFailure @@ -42,6 +44,8 @@ _IS_SYNC = True +pytestmark = pytest.mark.auth + # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. GSSAPI_HOST = os.environ.get("GSSAPI_HOST") GSSAPI_PORT = int(os.environ.get("GSSAPI_PORT", "27017")) From 6a7e83dc95319c445b95ab1f54f4aa8435986cc4 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 18 Oct 2024 10:36:05 -0500 Subject: [PATCH 33/35] PYTHON-4887 Do not test macos arm64 on server versions < 6.0 (#1947) --- .evergreen/config.yml | 48 ++++++++++++++++++++------- .evergreen/scripts/generate_config.py | 9 +++-- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 9e2ab77088..705880be60 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2592,7 +2592,11 @@ buildvariants: # Server tests for macOS Arm64. - name: test-macos-arm64-py3.9-auth-ssl-sync tasks: - - name: .standalone + - name: .standalone .6.0 + - name: .standalone .7.0 + - name: .standalone .8.0 + - name: .standalone .rapid + - name: .standalone .latest display_name: Test macOS Arm64 py3.9 Auth SSL Sync run_on: - macos-14-arm64 @@ -2600,11 +2604,15 @@ buildvariants: AUTH: auth SSL: ssl TEST_SUITES: default - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 SKIP_CSOT_TESTS: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 - name: test-macos-arm64-py3.9-auth-ssl-async tasks: - - name: .standalone + - name: .standalone .6.0 + - name: .standalone .7.0 + - name: .standalone .8.0 + - name: .standalone .rapid + - name: .standalone .latest display_name: Test macOS Arm64 py3.9 Auth SSL Async run_on: - macos-14-arm64 @@ -2612,11 +2620,15 @@ buildvariants: AUTH: auth SSL: ssl TEST_SUITES: default_async - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 SKIP_CSOT_TESTS: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 - name: test-macos-arm64-py3.13-noauth-ssl-sync tasks: - - name: .replica_set + - name: .replica_set .6.0 + - name: .replica_set .7.0 + - name: .replica_set .8.0 + - name: .replica_set .rapid + - name: .replica_set .latest display_name: Test macOS Arm64 py3.13 NoAuth SSL Sync run_on: - macos-14-arm64 @@ -2624,11 +2636,15 @@ buildvariants: AUTH: noauth SSL: ssl TEST_SUITES: default - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 SKIP_CSOT_TESTS: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 - name: test-macos-arm64-py3.13-noauth-ssl-async tasks: - - name: .replica_set + - name: .replica_set .6.0 + - name: .replica_set .7.0 + - name: .replica_set .8.0 + - name: .replica_set .rapid + - name: .replica_set .latest display_name: Test macOS Arm64 py3.13 NoAuth SSL Async run_on: - macos-14-arm64 @@ -2636,11 +2652,15 @@ buildvariants: AUTH: noauth SSL: ssl TEST_SUITES: default_async - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 SKIP_CSOT_TESTS: "true" + PYTHON_BINARY: 
/Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 - name: test-macos-arm64-py3.9-noauth-nossl-sync tasks: - - name: .sharded_cluster + - name: .sharded_cluster .6.0 + - name: .sharded_cluster .7.0 + - name: .sharded_cluster .8.0 + - name: .sharded_cluster .rapid + - name: .sharded_cluster .latest display_name: Test macOS Arm64 py3.9 NoAuth NoSSL Sync run_on: - macos-14-arm64 @@ -2648,11 +2668,15 @@ buildvariants: AUTH: noauth SSL: nossl TEST_SUITES: default - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 SKIP_CSOT_TESTS: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 - name: test-macos-arm64-py3.9-noauth-nossl-async tasks: - - name: .sharded_cluster + - name: .sharded_cluster .6.0 + - name: .sharded_cluster .7.0 + - name: .sharded_cluster .8.0 + - name: .sharded_cluster .rapid + - name: .sharded_cluster .latest display_name: Test macOS Arm64 py3.9 NoAuth NoSSL Async run_on: - macos-14-arm64 @@ -2660,8 +2684,8 @@ buildvariants: AUTH: noauth SSL: nossl TEST_SUITES: default_async - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 SKIP_CSOT_TESTS: "true" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 # Server tests for Windows. - name: test-win64-py3.9-auth-ssl-sync diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index c3bfeef7af..a3ec798c3b 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -23,6 +23,7 @@ ############## ALL_VERSIONS = ["4.0", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"] +VERSIONS_6_0_PLUS = ["6.0", "7.0", "8.0", "rapid", "latest"] CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] PYPYS = ["pypy3.9", "pypy3.10"] ALL_PYTHONS = CPYTHONS + PYPYS @@ -239,10 +240,14 @@ def create_server_variants() -> list[BuildVariant]: zip_cycle(MIN_MAX_PYTHON, AUTH_SSLS, TOPOLOGIES), SYNCS ): test_suite = "default" if sync == "sync" else "default_async" + tasks = [f".{topology}"] + # MacOS arm64 only works on server versions 6.0+ + if host == "macos-arm64": + tasks = [f".{topology} .{version}" for version in VERSIONS_6_0_PLUS] expansions = dict(AUTH=auth, SSL=ssl, TEST_SUITES=test_suite, SKIP_CSOT_TESTS="true") display_name = get_display_name("Test", host, python=python, **expansions) variant = create_variant( - [f".{topology}"], + tasks, display_name, python=python, host=host, @@ -409,6 +414,6 @@ def create_enterprise_auth_variants(): # Generate Config ################## -variants = create_enterprise_auth_variants() +variants = create_server_variants() # print(len(variants)) generate_yaml(variants=variants) From 1ae0c3904c9a71bed1f7d8f81183b59070845a81 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 18 Oct 2024 10:58:28 -0500 Subject: [PATCH 34/35] PYTHON-4886 Use shrub.py for PyOpenSSL tests (#1946) --- .evergreen/config.yml | 144 +++++++++++++++++--------- .evergreen/scripts/generate_config.py | 34 +++++- 2 files changed, 126 insertions(+), 52 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 705880be60..9083da145b 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2305,16 +2305,6 @@ axes: variables: COVERAGE: "coverage" - # Run pyopenssl tests? 
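PATCH 33 above limits the macOS arm64 variants to servers that ship arm64 binaries by expanding the bare topology tag into one task tag per supported version. A sketch of that expansion, with VERSIONS_6_0_PLUS copied from generate_config.py:

    VERSIONS_6_0_PLUS = ["6.0", "7.0", "8.0", "rapid", "latest"]

    def tasks_for(host: str, topology: str) -> list[str]:
        # macOS arm64 only works on server versions 6.0+.
        if host == "macos-arm64":
            return [f".{topology} .{version}" for version in VERSIONS_6_0_PLUS]
        return [f".{topology}"]

    assert tasks_for("macos-arm64", "standalone") == [
        ".standalone .6.0", ".standalone .7.0", ".standalone .8.0",
        ".standalone .rapid", ".standalone .latest",
    ]
    assert tasks_for("rhel8", "replica_set") == [".replica_set"]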
- - id: pyopenssl - display_name: "PyOpenSSL" - values: - - id: "enabled" - display_name: "PyOpenSSL" - variables: - test_pyopenssl: true - batchtime: 10080 # 7 days - - id: versionedApi display_name: "versionedApi" values: @@ -3283,6 +3273,99 @@ buildvariants: AUTH: auth PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 +# PyOpenSSL tests. +- name: pyopenssl-macos-py3.9 + tasks: + - name: .replica_set + - name: .7.0 + display_name: PyOpenSSL macOS py3.9 + run_on: + - macos-14 + batchtime: 10080 + expansions: + AUTH: noauth + test_pyopenssl: "true" + SSL: ssl + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 +- name: pyopenssl-rhel8-py3.10 + tasks: + - name: .replica_set + - name: .7.0 + display_name: PyOpenSSL RHEL8 py3.10 + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + test_pyopenssl: "true" + SSL: ssl + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: pyopenssl-rhel8-py3.11 + tasks: + - name: .replica_set + - name: .7.0 + display_name: PyOpenSSL RHEL8 py3.11 + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + test_pyopenssl: "true" + SSL: ssl + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: pyopenssl-rhel8-py3.12 + tasks: + - name: .replica_set + - name: .7.0 + display_name: PyOpenSSL RHEL8 py3.12 + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + test_pyopenssl: "true" + SSL: ssl + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: pyopenssl-win64-py3.13 + tasks: + - name: .replica_set + - name: .7.0 + display_name: PyOpenSSL Win64 py3.13 + run_on: + - windows-64-vsMulti-small + batchtime: 10080 + expansions: + AUTH: auth + test_pyopenssl: "true" + SSL: ssl + PYTHON_BINARY: C:/python/Python313/python.exe +- name: pyopenssl-rhel8-pypy3.9 + tasks: + - name: .replica_set + - name: .7.0 + display_name: PyOpenSSL RHEL8 pypy3.9 + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + test_pyopenssl: "true" + SSL: ssl + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: pyopenssl-rhel8-pypy3.10 + tasks: + - name: .replica_set + - name: .7.0 + display_name: PyOpenSSL RHEL8 pypy3.10 + run_on: + - rhel87-small + batchtime: 10080 + expansions: + AUTH: auth + test_pyopenssl: "true" + SSL: ssl + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + - matrix_name: "tests-fips" matrix_spec: platform: @@ -3305,47 +3388,6 @@ buildvariants: tasks: - ".6.0" -- matrix_name: "tests-pyopenssl" - matrix_spec: - platform: rhel8 - python-version: "*" - auth: "*" - ssl: "ssl" - pyopenssl: "*" - # Only test "noauth" with Python 3.9. - exclude_spec: - platform: rhel8 - python-version: ["3.10", "3.11", "3.12", "3.13", "pypy3.9", "pypy3.10"] - auth: "noauth" - ssl: "ssl" - pyopenssl: "*" - display_name: "PyOpenSSL ${platform} ${python-version} ${auth}" - tasks: - - '.replica_set' - # Test standalone and sharded only on 7.0. 
- - '.7.0' - -- matrix_name: "tests-pyopenssl-macOS" - matrix_spec: - platform: macos - auth: "auth" - ssl: "ssl" - pyopenssl: "*" - display_name: "PyOpenSSL ${platform} ${auth}" - tasks: - - '.replica_set' - -- matrix_name: "tests-pyopenssl-windows" - matrix_spec: - platform: windows - python-version-windows: "*" - auth: "auth" - ssl: "ssl" - pyopenssl: "*" - display_name: "PyOpenSSL ${platform} ${python-version-windows} ${auth}" - tasks: - - '.replica_set' - - matrix_name: "tests-python-version-rhel8-without-c-extensions" matrix_spec: platform: rhel8 diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index a3ec798c3b..6d614a9afe 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -410,10 +410,42 @@ def create_enterprise_auth_variants(): return variants +def create_pyopenssl_variants(): + base_name = "PyOpenSSL" + batchtime = BATCHTIME_WEEK + base_expansions = dict(test_pyopenssl="true", SSL="ssl") + variants = [] + + for python in ALL_PYTHONS: + # Only test "noauth" with min python. + auth = "noauth" if python == CPYTHONS[0] else "auth" + if python == CPYTHONS[0]: + host = "macos" + elif python == CPYTHONS[-1]: + host = "win64" + else: + host = "rhel8" + expansions = dict(AUTH=auth) + expansions.update(base_expansions) + + display_name = get_display_name(base_name, host, python=python) + variant = create_variant( + [".replica_set", ".7.0"], + display_name, + python=python, + host=host, + expansions=expansions, + batchtime=batchtime, + ) + variants.append(variant) + + return variants + + ################## # Generate Config ################## -variants = create_server_variants() +variants = create_pyopenssl_variants() # print(len(variants)) generate_yaml(variants=variants) From a1ade45dd3b67a6e4baa50404b9807f68062f043 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 18 Oct 2024 13:32:09 -0400 Subject: [PATCH 35/35] PYTHON-4881 - Use OvertCommandListener wherever sensitive events are not needed (#1943) Co-authored-by: Steven Silvester --- test/asynchronous/test_change_stream.py | 5 +++-- test/asynchronous/test_collation.py | 4 ++-- test/asynchronous/test_collection.py | 3 ++- test/asynchronous/test_cursor.py | 4 ++-- test/asynchronous/test_grid_file.py | 4 ++-- test/asynchronous/test_monitoring.py | 21 ++++++++++++--------- test/asynchronous/test_session.py | 3 ++- test/auth_oidc/test_auth_oidc.py | 6 +++--- test/test_change_stream.py | 5 +++-- test/test_collation.py | 4 ++-- test/test_collection.py | 3 ++- test/test_cursor.py | 4 ++-- test/test_grid_file.py | 4 ++-- test/test_index_management.py | 4 ++-- test/test_monitoring.py | 21 ++++++++++++--------- test/test_read_write_concern_spec.py | 6 +++--- test/test_server_selection.py | 3 ++- test/test_session.py | 3 ++- test/test_ssl.py | 1 + 19 files changed, 61 insertions(+), 47 deletions(-) diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 883ed72c4c..98641f46ee 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -39,6 +39,7 @@ from test.utils import ( AllowListEventListener, EventListener, + OvertCommandListener, async_wait_until, ) @@ -179,7 +180,7 @@ async def _wait_until(): @no_type_check async def test_try_next_runs_one_getmore(self): - listener = EventListener() + listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. 
await client.admin.command("ping") @@ -237,7 +238,7 @@ async def _wait_until(): @no_type_check async def test_batch_size_is_honored(self): - listener = EventListener() + listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. await client.admin.command("ping") diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index be3ea22e42..d95f4c9917 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import EventListener +from test.utils import EventListener, OvertCommandListener from typing import Any from pymongo.asynchronous.helpers import anext @@ -101,7 +101,7 @@ class TestCollation(AsyncIntegrationTest): @async_client_context.require_connection async def _setup_class(cls): await super()._setup_class() - cls.listener = EventListener() + cls.listener = OvertCommandListener() cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test cls.collation = Collation("en_US") diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 612090b69f..db52bad4ac 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -36,6 +36,7 @@ from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, + OvertCommandListener, async_get_pool, async_is_mongos, async_wait_until, @@ -2116,7 +2117,7 @@ async def test_find_one_and(self): self.assertEqual(4, (await c.find_one_and_update({}, {"$inc": {"i": 1}}, sort=sort))["j"]) async def test_find_one_and_write_concern(self): - listener = EventListener() + listener = OvertCommandListener() db = (await self.async_single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index ee0a757ed3..787da3d957 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1601,7 +1601,7 @@ async def test_read_concern(self): await anext(c.find_raw_batches()) async def test_monitoring(self): - listener = EventListener() + listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() @@ -1768,7 +1768,7 @@ async def test_collation(self): await anext(await self.db.test.aggregate_raw_batches([], collation=Collation("en_US"))) async def test_monitoring(self): - listener = EventListener() + listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 9c57c15c5a..54fcd3abf6 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import EventListener +from test.utils import OvertCommandListener from bson.objectid import ObjectId from gridfs.asynchronous.grid_file import ( @@ -810,7 +810,7 @@ async def test_survive_cursor_not_found(self): # Use 102 batches to cause a single getMore. 
chunk_size = 1024 data = b"d" * (102 * chunk_size) - listener = EventListener() + listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile: diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index b5d8708dc3..b0c86ab54e 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -31,6 +31,7 @@ ) from test.utils import ( EventListener, + OvertCommandListener, async_wait_until, ) @@ -54,7 +55,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): @async_client_context.require_connection async def _setup_class(cls): await super()._setup_class() - cls.listener = EventListener() + cls.listener = OvertCommandListener() cls.client = await cls.unmanaged_async_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False ) @@ -1100,11 +1101,13 @@ async def test_first_batch_helper(self): @async_client_context.require_version_max(6, 1, 99) async def test_sensitive_commands(self): - listeners = self.client._event_listeners + listener = EventListener() + client = await self.async_rs_or_single_client(event_listeners=[listener]) + listeners = client._event_listeners - self.listener.reset() + listener.reset() cmd = SON([("getnonce", 1)]) - listeners.publish_command_start(cmd, "pymongo_test", 12345, await self.client.address, None) # type: ignore[arg-type] + listeners.publish_command_start(cmd, "pymongo_test", 12345, await client.address, None) # type: ignore[arg-type] delta = datetime.timedelta(milliseconds=100) listeners.publish_command_success( delta, @@ -1115,15 +1118,15 @@ async def test_sensitive_commands(self): None, database_name="pymongo_test", ) - started = self.listener.started_events[0] - succeeded = self.listener.succeeded_events[0] - self.assertEqual(0, len(self.listener.failed_events)) + started = listener.started_events[0] + succeeded = listener.succeeded_events[0] + self.assertEqual(0, len(listener.failed_events)) self.assertIsInstance(started, monitoring.CommandStartedEvent) self.assertEqual({}, started.command) self.assertEqual("pymongo_test", started.database_name) self.assertEqual("getnonce", started.command_name) self.assertIsInstance(started.request_id, int) - self.assertEqual(await self.client.address, started.connection_id) + self.assertEqual(await client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertEqual(succeeded.duration_micros, 100000) self.assertEqual(started.command_name, succeeded.command_name) @@ -1140,7 +1143,7 @@ class AsyncTestGlobalListener(AsyncIntegrationTest): @async_client_context.require_connection async def _setup_class(cls): await super()._setup_class() - cls.listener = EventListener() + cls.listener = OvertCommandListener() # We plan to call register(), which internally modifies _LISTENERS. 
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index d264b5ecb0..b432621798 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -36,6 +36,7 @@ from test.utils import ( EventListener, ExceptionCatchingThread, + OvertCommandListener, async_wait_until, wait_until, ) @@ -199,7 +200,7 @@ def test_implicit_sessions_checkout(self): lsid_set = set() failures = 0 for _ in range(5): - listener = EventListener() + listener = OvertCommandListener() client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 6d31f3db4e..6526391daf 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.unified_format import generate_test_classes -from test.utils import EventListener +from test.utils import EventListener, OvertCommandListener from bson import SON from pymongo import MongoClient @@ -348,7 +348,7 @@ def test_4_1_reauthenticate_succeeds(self): # Create a default OIDC client and add an event listener. # The following assumes that the driver does not emit saslStart or saslContinue events. # If the driver does emit those events, ignore/filter them for the purposes of this test. - listener = EventListener() + listener = OvertCommandListener() client = self.create_client(event_listeners=[listener]) # Perform a find operation that succeeds. @@ -1021,7 +1021,7 @@ def fetch(self, _): def test_4_4_speculative_authentication_should_be_ignored_on_reauthentication(self): # Create an OIDC configured client that can listen for `SaslStart` commands. - listener = EventListener() + listener = OvertCommandListener() client = self.create_client(event_listeners=[listener]) # Preload the *Client Cache* with a valid access token to enforce Speculative Authentication. diff --git a/test/test_change_stream.py b/test/test_change_stream.py index dae224c5e0..3a107122b7 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -39,6 +39,7 @@ from test.utils import ( AllowListEventListener, EventListener, + OvertCommandListener, wait_until, ) @@ -177,7 +178,7 @@ def _wait_until(): @no_type_check def test_try_next_runs_one_getmore(self): - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command("ping") @@ -235,7 +236,7 @@ def _wait_until(): @no_type_check def test_batch_size_is_honored(self): - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. 
client.admin.command("ping") diff --git a/test/test_collation.py b/test/test_collation.py index e5c1c7eb11..b878df2fb4 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test import IntegrationTest, client_context, unittest -from test.utils import EventListener +from test.utils import EventListener, OvertCommandListener from typing import Any from pymongo.collation import ( @@ -101,7 +101,7 @@ class TestCollation(IntegrationTest): @client_context.require_connection def _setup_class(cls): super()._setup_class() - cls.listener = EventListener() + cls.listener = OvertCommandListener() cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test cls.collation = Collation("en_US") diff --git a/test/test_collection.py b/test/test_collection.py index a2c3b0b0b6..84a900d45b 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -36,6 +36,7 @@ from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, + OvertCommandListener, get_pool, is_mongos, wait_until, @@ -2093,7 +2094,7 @@ def test_find_one_and(self): self.assertEqual(4, (c.find_one_and_update({}, {"$inc": {"i": 1}}, sort=sort))["j"]) def test_find_one_and_write_concern(self): - listener = EventListener() + listener = OvertCommandListener() db = (self.single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) diff --git a/test/test_cursor.py b/test/test_cursor.py index 7a6dfc9429..9eac0f1c49 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1590,7 +1590,7 @@ def test_read_concern(self): next(c.find_raw_batches()) def test_monitoring(self): - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() @@ -1757,7 +1757,7 @@ def test_collation(self): next(self.db.test.aggregate_raw_batches([], collation=Collation("en_US"))) def test_monitoring(self): - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() diff --git a/test/test_grid_file.py b/test/test_grid_file.py index fe88aec5ff..c35efccef5 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import EventListener +from test.utils import OvertCommandListener from bson.objectid import ObjectId from gridfs.errors import NoFile @@ -808,7 +808,7 @@ def test_survive_cursor_not_found(self): # Use 102 batches to cause a single getMore. 
chunk_size = 1024 data = b"d" * (102 * chunk_size) - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test with GridIn(db.fs, chunk_size=chunk_size) as infile: diff --git a/test/test_index_management.py b/test/test_index_management.py index ec1e363737..6ca726e2e0 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -27,7 +27,7 @@ from test import IntegrationTest, PyMongoTestCase, unittest from test.unified_format import generate_test_classes -from test.utils import AllowListEventListener, EventListener +from test.utils import AllowListEventListener, EventListener, OvertCommandListener from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel @@ -88,7 +88,7 @@ def setUpClass(cls) -> None: url = os.environ.get("MONGODB_URI") username = os.environ["DB_USER"] password = os.environ["DB_PASSWORD"] - cls.listener = listener = EventListener() + cls.listener = listener = OvertCommandListener() cls.client = cls.unmanaged_simple_client( url, username=username, password=password, event_listeners=[listener] ) diff --git a/test/test_monitoring.py b/test/test_monitoring.py index a0c520ed27..75fe5c987a 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -31,6 +31,7 @@ ) from test.utils import ( EventListener, + OvertCommandListener, wait_until, ) @@ -54,7 +55,7 @@ class TestCommandMonitoring(IntegrationTest): @client_context.require_connection def _setup_class(cls): super()._setup_class() - cls.listener = EventListener() + cls.listener = OvertCommandListener() cls.client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False ) @@ -1100,11 +1101,13 @@ def test_first_batch_helper(self): @client_context.require_version_max(6, 1, 99) def test_sensitive_commands(self): - listeners = self.client._event_listeners + listener = EventListener() + client = self.rs_or_single_client(event_listeners=[listener]) + listeners = client._event_listeners - self.listener.reset() + listener.reset() cmd = SON([("getnonce", 1)]) - listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address, None) # type: ignore[arg-type] + listeners.publish_command_start(cmd, "pymongo_test", 12345, client.address, None) # type: ignore[arg-type] delta = datetime.timedelta(milliseconds=100) listeners.publish_command_success( delta, @@ -1115,15 +1118,15 @@ def test_sensitive_commands(self): None, database_name="pymongo_test", ) - started = self.listener.started_events[0] - succeeded = self.listener.succeeded_events[0] - self.assertEqual(0, len(self.listener.failed_events)) + started = listener.started_events[0] + succeeded = listener.succeeded_events[0] + self.assertEqual(0, len(listener.failed_events)) self.assertIsInstance(started, monitoring.CommandStartedEvent) self.assertEqual({}, started.command) self.assertEqual("pymongo_test", started.database_name) self.assertEqual("getnonce", started.command_name) self.assertIsInstance(started.request_id, int) - self.assertEqual(self.client.address, started.connection_id) + self.assertEqual(client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertEqual(succeeded.duration_micros, 100000) self.assertEqual(started.command_name, succeeded.command_name) @@ -1140,7 +1143,7 @@ class TestGlobalListener(IntegrationTest): @client_context.require_connection def _setup_class(cls): super()._setup_class() - cls.listener = 
EventListener() + cls.listener = OvertCommandListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 67943d495d..db53b67ae4 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -24,7 +24,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import EventListener +from test.utils import OvertCommandListener from pymongo import DESCENDING from pymongo.errors import ( @@ -44,7 +44,7 @@ class TestReadWriteConcernSpec(IntegrationTest): def test_omit_default_read_write_concern(self): - listener = EventListener() + listener = OvertCommandListener() # Client with default readConcern and writeConcern client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) @@ -205,7 +205,7 @@ def test_error_includes_errInfo(self): @client_context.require_version_min(4, 9) def test_write_error_details_exposes_errinfo(self): - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) db = client.errinfotest diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 67e9716bf4..984b967f50 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -33,6 +33,7 @@ from test.utils import ( EventListener, FunctionCallRecorder, + OvertCommandListener, wait_until, ) from test.utils_selection_tests import ( @@ -74,7 +75,7 @@ def custom_selector(servers): return [servers[idx]] # Initialize client with appropriate listeners. - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client( server_selector=custom_selector, event_listeners=[listener] ) diff --git a/test/test_session.py b/test/test_session.py index 9f94ded927..d0bbb075a8 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -36,6 +36,7 @@ from test.utils import ( EventListener, ExceptionCatchingThread, + OvertCommandListener, wait_until, ) @@ -198,7 +199,7 @@ def test_implicit_sessions_checkout(self): lsid_set = set() failures = 0 for _ in range(5): - listener = EventListener() + listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ diff --git a/test/test_ssl.py b/test/test_ssl.py index 36d7ba12b6..04db9b61a4 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -33,6 +33,7 @@ ) from test.utils import ( EventListener, + OvertCommandListener, cat_files, ignore_deprecations, )
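PATCH 35 standardizes on OvertCommandListener in tests that never inspect security-sensitive commands, reserving the unfiltered EventListener for tests like test_sensitive_commands that do. The class itself lives in test/utils.py and is not shown in this diff; a plausible minimal version, with a stubbed EventListener base, might look like the following (the SENSITIVE_COMMANDS set here is an assumption, not the real list):

    class EventListener:
        """Stub of the test/utils.py helper: records every command event."""

        def __init__(self):
            self.started_events = []
            self.succeeded_events = []
            self.failed_events = []

        def started(self, event):
            self.started_events.append(event)

        def succeeded(self, event):
            self.succeeded_events.append(event)

        def failed(self, event):
            self.failed_events.append(event)

    SENSITIVE_COMMANDS = {"authenticate", "saslstart", "saslcontinue", "getnonce"}

    class OvertCommandListener(EventListener):
        """Records only events for commands that are not security-sensitive."""

        def started(self, event):
            if event.command_name.lower() not in SENSITIVE_COMMANDS:
                super().started(event)

        def succeeded(self, event):
            if event.command_name.lower() not in SENSITIVE_COMMANDS:
                super().succeeded(event)

        def failed(self, event):
            if event.command_name.lower() not in SENSITIVE_COMMANDS:
                super().failed(event)

Filtering at the listener level keeps assertions like "the first started event is the find command" stable even when the driver authenticates or runs other sensitive commands on the same connection.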