From f7bb39e4755eacdcf156aa4f58abf256e8c3f034 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 14 Jan 2025 16:04:36 -0500 Subject: [PATCH 1/9] PYTHON-5021 - Fix usages of getaddrinfo to be non-blocking --- pymongo/asynchronous/auth.py | 24 +++++++++++++++++++----- pymongo/asynchronous/pool.py | 13 ++++++++++--- pymongo/synchronous/auth.py | 20 +++++++++++++++++--- pymongo/synchronous/pool.py | 9 ++++++++- test/asynchronous/test_auth.py | 4 ++-- 5 files changed, 56 insertions(+), 14 deletions(-) diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index 48ce4bbd39..748c199e9a 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -15,6 +15,7 @@ """Authentication helpers.""" from __future__ import annotations +import asyncio import functools import hashlib import hmac @@ -177,15 +178,28 @@ def _auth_key(nonce: str, username: str, password: str) -> str: return md5hash.hexdigest() -def _canonicalize_hostname(hostname: str, option: str | bool) -> str: +async def _canonicalize_hostname(hostname: str, option: str | bool) -> str: """Canonicalize hostname following MIT-krb5 behavior.""" # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520 if option in [False, "none"]: return hostname - af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( - hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME - )[0] + if not _IS_SYNC: + loop = asyncio.get_event_loop() + af, socktype, proto, canonname, sockaddr = ( + await loop.getaddrinfo( + hostname, + None, + family=0, + type=0, + proto=socket.IPPROTO_TCP, + flags=socket.AI_CANONNAME, + ) + )[0] # type: ignore[index] + else: + af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( + hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME + )[0] # For forward just to resolve the cname as dns.lookup() will not return it. if option == "forward": @@ -213,7 +227,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti # Starting here and continuing through the while loop below - establish # the security context. See RFC 4752, Section 3.1, first paragraph. host = props.service_host or conn.address[0] - host = _canonicalize_hostname(host, props.canonicalize_host_name) + host = await _canonicalize_hostname(host, props.canonicalize_host_name) service = props.service_name + "@" + host if props.service_realm is not None: service = service + "@" + props.service_realm diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 5dc5675a0a..943c7a9ff4 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -783,7 +783,7 @@ def __repr__(self) -> str: ) -def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: +async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: """Given (host, port) and PoolOptions, connect and return a socket object. Can raise socket.error. @@ -814,7 +814,14 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket family = socket.AF_UNSPEC err = None - for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + if not _IS_SYNC: + loop = asyncio.get_event_loop() + results = await loop.getaddrinfo( # type: ignore[assignment] + host, port, family=family, type=socket.SOCK_STREAM + ) + else: + results = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM) # type: ignore[assignment] + for res in results: # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 @@ -863,7 +870,7 @@ async def _configured_socket( Sets socket's SSL and timeout options. """ - sock = _create_connection(address, options) + sock = await _create_connection(address, options) ssl_context = options._ssl_context if ssl_context is None: diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index 0e51ff8b7f..c48bd276ba 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -15,6 +15,7 @@ """Authentication helpers.""" from __future__ import annotations +import asyncio import functools import hashlib import hmac @@ -180,9 +181,22 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str: if option in [False, "none"]: return hostname - af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( - hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME - )[0] + if not _IS_SYNC: + loop = asyncio.get_event_loop() + af, socktype, proto, canonname, sockaddr = ( + loop.getaddrinfo( + hostname, + None, + family=0, + type=0, + proto=socket.IPPROTO_TCP, + flags=socket.AI_CANONNAME, + ) + )[0] # type: ignore[index] + else: + af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( + hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME + )[0] # For forward just to resolve the cname as dns.lookup() will not return it. if option == "forward": diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1a155c82d7..16465f4f77 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -812,7 +812,14 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket family = socket.AF_UNSPEC err = None - for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + if not _IS_SYNC: + loop = asyncio.get_event_loop() + results = loop.getaddrinfo( # type: ignore[assignment] + host, port, family=family, type=socket.SOCK_STREAM + ) + else: + results = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM) # type: ignore[assignment] + for res in results: # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 08dc4d7247..7172152d69 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -275,10 +275,10 @@ async def test_gssapi_threaded(self): async def test_gssapi_canonicalize_host_name(self): # Test the low level method. assert GSSAPI_HOST is not None - result = _canonicalize_hostname(GSSAPI_HOST, "forward") + result = await _canonicalize_hostname(GSSAPI_HOST, "forward") if "compute-1.amazonaws.com" not in result: self.assertEqual(result, GSSAPI_HOST) - result = _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse") + result = await _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse") self.assertEqual(result, GSSAPI_HOST) # Use the equivalent named CANONICALIZE_HOST_NAME. From f6c01364439c45b8ae28c5ba78d066ec5d905dc2 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 15 Jan 2025 09:31:39 -0500 Subject: [PATCH 2/9] Use get_running_loop instead of get_event_loop --- pymongo/asynchronous/auth.py | 2 +- pymongo/asynchronous/pool.py | 2 +- pymongo/synchronous/auth.py | 2 +- pymongo/synchronous/pool.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index 748c199e9a..24b8d592e2 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -185,7 +185,7 @@ async def _canonicalize_hostname(hostname: str, option: str | bool) -> str: return hostname if not _IS_SYNC: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() af, socktype, proto, canonname, sockaddr = ( await loop.getaddrinfo( hostname, diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 943c7a9ff4..174a319601 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -815,7 +815,7 @@ async def _create_connection(address: _Address, options: PoolOptions) -> socket. err = None if not _IS_SYNC: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() results = await loop.getaddrinfo( # type: ignore[assignment] host, port, family=family, type=socket.SOCK_STREAM ) diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index c48bd276ba..4ebf606e8a 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -182,7 +182,7 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str: return hostname if not _IS_SYNC: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() af, socktype, proto, canonname, sockaddr = ( loop.getaddrinfo( hostname, diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 16465f4f77..59773c7c75 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -813,7 +813,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket err = None if not _IS_SYNC: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() results = loop.getaddrinfo( # type: ignore[assignment] host, port, family=family, type=socket.SOCK_STREAM ) From 076a014262eb7a9db52eb8662efca8477db25af3 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 15 Jan 2025 11:13:36 -0500 Subject: [PATCH 3/9] Use our own executor --- pymongo/__init__.py | 4 ++++ pymongo/asynchronous/pool.py | 8 +++++--- pymongo/pyopenssl_context.py | 5 +++-- pymongo/synchronous/pool.py | 8 +++++--- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 58f6ff338b..8b4f1fa50f 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -15,6 +15,7 @@ """Python driver for MongoDB.""" from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor from typing import ContextManager, Optional __all__ = [ @@ -166,3 +167,6 @@ def timeout(seconds: Optional[float]) -> ContextManager[None]: if seconds is not None: seconds = float(seconds) return _csot._TimeoutContext(seconds) + + +_PYMONGO_EXECUTOR = ThreadPoolExecutor(thread_name_prefix="PYMONGO_EXECUTOR-") diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 174a319601..45160e2ad7 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -38,7 +38,7 @@ ) from bson import DEFAULT_CODEC_OPTIONS -from pymongo import _csot, helpers_shared +from pymongo import _PYMONGO_EXECUTOR, _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern from pymongo.asynchronous.helpers import _handle_reauth from pymongo.asynchronous.network import command, receive_message @@ -890,7 +890,7 @@ async def _configured_socket( else: loop = asyncio.get_running_loop() ssl_sock = await loop.run_in_executor( - None, + _PYMONGO_EXECUTOR, functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] ) else: @@ -901,7 +901,9 @@ async def _configured_socket( ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] else: loop = asyncio.get_running_loop() - ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] + ssl_sock = await loop.run_in_executor( + _PYMONGO_EXECUTOR, ssl_context.wrap_socket, sock + ) # type: ignore[assignment, misc] except _CertificateError: sock.close() # Raise _CertificateError directly like we do after match_hostname diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 8c643394b2..a60ee9e0bf 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -31,6 +31,7 @@ from OpenSSL import SSL as _SSL from OpenSSL import crypto as _crypto +from pymongo import _PYMONGO_EXECUTOR from pymongo.errors import ConfigurationError as _ConfigurationError from pymongo.errors import _CertificateError # type:ignore[attr-defined] from pymongo.ocsp_cache import _OCSPCache @@ -405,7 +406,7 @@ async def a_wrap_socket( ssl_conn.set_tlsext_host_name(server_hostname.encode("idna")) if self.verify_mode != _stdlibssl.CERT_NONE: # Request a stapled OCSP response. - await loop.run_in_executor(None, ssl_conn.request_ocsp) + await loop.run_in_executor(_PYMONGO_EXECUTOR, ssl_conn.request_ocsp) ssl_conn.set_connect_state() # If this wasn't true the caller of wrap_socket would call # do_handshake() @@ -413,7 +414,7 @@ async def a_wrap_socket( # XXX: If we do hostname checking in a callback we can get rid # of this call to do_handshake() since the handshake # will happen automatically later. - await loop.run_in_executor(None, ssl_conn.do_handshake) + await loop.run_in_executor(_PYMONGO_EXECUTOR, ssl_conn.do_handshake) # XXX: Do this in a callback registered with # SSLContext.set_info_callback? See Twisted for an example. if self.check_hostname and server_hostname is not None: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 59773c7c75..c0739058e6 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -38,7 +38,7 @@ ) from bson import DEFAULT_CODEC_OPTIONS -from pymongo import _csot, helpers_shared +from pymongo import _PYMONGO_EXECUTOR, _csot, helpers_shared from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, @@ -886,7 +886,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. else: loop = asyncio.get_running_loop() ssl_sock = loop.run_in_executor( - None, + _PYMONGO_EXECUTOR, functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] ) else: @@ -897,7 +897,9 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] else: loop = asyncio.get_running_loop() - ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] + ssl_sock = loop.run_in_executor( + _PYMONGO_EXECUTOR, ssl_context.wrap_socket, sock + ) # type: ignore[assignment, misc] except _CertificateError: sock.close() # Raise _CertificateError directly like we do after match_hostname From 28215668550cdccefed67a58e0544842ec62e37c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 15 Jan 2025 11:32:08 -0500 Subject: [PATCH 4/9] Fix import of _PYMONGO_EXECUTOR --- pymongo/_asyncio_executor.py | 22 ++++++++++++++++++++++ pymongo/asynchronous/pool.py | 3 ++- pymongo/pyopenssl_context.py | 2 +- pymongo/synchronous/pool.py | 3 ++- 4 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 pymongo/_asyncio_executor.py diff --git a/pymongo/_asyncio_executor.py b/pymongo/_asyncio_executor.py new file mode 100644 index 0000000000..189906a55c --- /dev/null +++ b/pymongo/_asyncio_executor.py @@ -0,0 +1,22 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A separate ThreadPoolExecutor instance used internally to avoid competing for resources with the default asyncio ThreadPoolExecutor +that user code will use.""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor + +_PYMONGO_EXECUTOR = ThreadPoolExecutor(thread_name_prefix="PYMONGO_EXECUTOR-") diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 45160e2ad7..f8dac8cd7b 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -38,7 +38,8 @@ ) from bson import DEFAULT_CODEC_OPTIONS -from pymongo import _PYMONGO_EXECUTOR, _csot, helpers_shared +from pymongo import _csot, helpers_shared +from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.asynchronous.client_session import _validate_session_write_concern from pymongo.asynchronous.helpers import _handle_reauth from pymongo.asynchronous.network import command, receive_message diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index a60ee9e0bf..038ef8df13 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -31,7 +31,7 @@ from OpenSSL import SSL as _SSL from OpenSSL import crypto as _crypto -from pymongo import _PYMONGO_EXECUTOR +from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.errors import ConfigurationError as _ConfigurationError from pymongo.errors import _CertificateError # type:ignore[attr-defined] from pymongo.ocsp_cache import _OCSPCache diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index c0739058e6..e3392b3f6e 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -38,7 +38,8 @@ ) from bson import DEFAULT_CODEC_OPTIONS -from pymongo import _PYMONGO_EXECUTOR, _csot, helpers_shared +from pymongo import _csot, helpers_shared +from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, From 297bf9cec86cce71d87333c3120ff5b38e1c677a Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 15 Jan 2025 11:45:22 -0500 Subject: [PATCH 5/9] getaddrinfo helper method --- pymongo/asynchronous/auth.py | 28 +++++++++++----------------- pymongo/asynchronous/helpers.py | 12 ++++++++++++ pymongo/asynchronous/pool.py | 11 ++--------- pymongo/synchronous/auth.py | 28 +++++++++++----------------- pymongo/synchronous/helpers.py | 12 ++++++++++++ pymongo/synchronous/pool.py | 11 ++--------- 6 files changed, 50 insertions(+), 52 deletions(-) diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index 24b8d592e2..fbabdf66b3 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -15,7 +15,6 @@ """Authentication helpers.""" from __future__ import annotations -import asyncio import functools import hashlib import hmac @@ -39,6 +38,7 @@ _authenticate_oidc, _get_authenticator, ) +from pymongo.asynchronous.helpers import getaddrinfo from pymongo.auth_shared import ( MongoCredential, _authenticate_scram_start, @@ -184,22 +184,16 @@ async def _canonicalize_hostname(hostname: str, option: str | bool) -> str: if option in [False, "none"]: return hostname - if not _IS_SYNC: - loop = asyncio.get_running_loop() - af, socktype, proto, canonname, sockaddr = ( - await loop.getaddrinfo( - hostname, - None, - family=0, - type=0, - proto=socket.IPPROTO_TCP, - flags=socket.AI_CANONNAME, - ) - )[0] # type: ignore[index] - else: - af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( - hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME - )[0] + af, socktype, proto, canonname, sockaddr = ( + await getaddrinfo( + hostname, + None, + family=0, + type=0, + proto=socket.IPPROTO_TCP, + flags=socket.AI_CANONNAME, + ) + )[0] # type: ignore[index] # For forward just to resolve the cname as dns.lookup() will not return it. if option == "forward": diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 1ac8b6630f..b591f0a4da 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -15,7 +15,9 @@ """Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations +import asyncio import builtins +import socket import sys from typing import ( Any, @@ -68,6 +70,16 @@ async def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +async def getaddrinfo(host, port, **kwargs): + if not _IS_SYNC: + loop = asyncio.get_running_loop() + return await loop.getaddrinfo( # type: ignore[assignment] + host, port, **kwargs + ) + else: + return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment] + + if sys.version_info >= (3, 10): anext = builtins.anext aiter = builtins.aiter diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index f8dac8cd7b..9683feff7f 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -41,7 +41,7 @@ from pymongo import _csot, helpers_shared from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.asynchronous.client_session import _validate_session_write_concern -from pymongo.asynchronous.helpers import _handle_reauth +from pymongo.asynchronous.helpers import _handle_reauth, getaddrinfo from pymongo.asynchronous.network import command, receive_message from pymongo.common import ( MAX_BSON_SIZE, @@ -815,14 +815,7 @@ async def _create_connection(address: _Address, options: PoolOptions) -> socket. family = socket.AF_UNSPEC err = None - if not _IS_SYNC: - loop = asyncio.get_running_loop() - results = await loop.getaddrinfo( # type: ignore[assignment] - host, port, family=family, type=socket.SOCK_STREAM - ) - else: - results = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM) # type: ignore[assignment] - for res in results: # type: ignore[attr-defined] + for res in await getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index 4ebf606e8a..016feee49f 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -15,7 +15,6 @@ """Authentication helpers.""" from __future__ import annotations -import asyncio import functools import hashlib import hmac @@ -46,6 +45,7 @@ _authenticate_oidc, _get_authenticator, ) +from pymongo.synchronous.helpers import getaddrinfo if TYPE_CHECKING: from pymongo.hello import Hello @@ -181,22 +181,16 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str: if option in [False, "none"]: return hostname - if not _IS_SYNC: - loop = asyncio.get_running_loop() - af, socktype, proto, canonname, sockaddr = ( - loop.getaddrinfo( - hostname, - None, - family=0, - type=0, - proto=socket.IPPROTO_TCP, - flags=socket.AI_CANONNAME, - ) - )[0] # type: ignore[index] - else: - af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( - hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME - )[0] + af, socktype, proto, canonname, sockaddr = ( + getaddrinfo( + hostname, + None, + family=0, + type=0, + proto=socket.IPPROTO_TCP, + flags=socket.AI_CANONNAME, + ) + )[0] # type: ignore[index] # For forward just to resolve the cname as dns.lookup() will not return it. if option == "forward": diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 064583dad3..b65ce2fdb0 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -15,7 +15,9 @@ """Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations +import asyncio import builtins +import socket import sys from typing import ( Any, @@ -68,6 +70,16 @@ def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +def getaddrinfo(host, port, **kwargs): + if not _IS_SYNC: + loop = asyncio.get_running_loop() + return loop.getaddrinfo( # type: ignore[assignment] + host, port, **kwargs + ) + else: + return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment] + + if sys.version_info >= (3, 10): next = builtins.next iter = builtins.iter diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index e3392b3f6e..b4a310c903 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -85,7 +85,7 @@ from pymongo.socket_checker import SocketChecker from pymongo.ssl_support import HAS_SNI, SSLError from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.helpers import _handle_reauth +from pymongo.synchronous.helpers import _handle_reauth, getaddrinfo from pymongo.synchronous.network import command, receive_message if TYPE_CHECKING: @@ -813,14 +813,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket family = socket.AF_UNSPEC err = None - if not _IS_SYNC: - loop = asyncio.get_running_loop() - results = loop.getaddrinfo( # type: ignore[assignment] - host, port, family=family, type=socket.SOCK_STREAM - ) - else: - results = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM) # type: ignore[assignment] - for res in results: # type: ignore[attr-defined] + for res in getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 From 39ade36b6f07c24e22ce52b314ba7a0ff975d1d7 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 15 Jan 2025 11:47:07 -0500 Subject: [PATCH 6/9] cleanup --- pymongo/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 8b4f1fa50f..58f6ff338b 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -15,7 +15,6 @@ """Python driver for MongoDB.""" from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor from typing import ContextManager, Optional __all__ = [ @@ -167,6 +166,3 @@ def timeout(seconds: Optional[float]) -> ContextManager[None]: if seconds is not None: seconds = float(seconds) return _csot._TimeoutContext(seconds) - - -_PYMONGO_EXECUTOR = ThreadPoolExecutor(thread_name_prefix="PYMONGO_EXECUTOR-") From 5e3bc65ab24aaed16b4dff6b0de46d3b8059f1a1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 15 Jan 2025 12:08:14 -0500 Subject: [PATCH 7/9] Use run_in_executor for getaddrinfo --- pymongo/asynchronous/helpers.py | 20 ++++++++++++++++---- pymongo/synchronous/helpers.py | 20 ++++++++++++++++---- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index b591f0a4da..4fba407267 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -17,6 +17,7 @@ import asyncio import builtins +import functools import socket import sys from typing import ( @@ -26,6 +27,7 @@ cast, ) +from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.errors import ( OperationFailure, ) @@ -70,14 +72,24 @@ async def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) -async def getaddrinfo(host, port, **kwargs): +async def getaddrinfo( + host: Any, port: Any, **kwargs: Any +) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: if not _IS_SYNC: loop = asyncio.get_running_loop() - return await loop.getaddrinfo( # type: ignore[assignment] - host, port, **kwargs + return await loop.run_in_executor( # type: ignore[return-value] + _PYMONGO_EXECUTOR, functools.partial(socket.getaddrinfo, host, port, **kwargs) ) else: - return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment] + return socket.getaddrinfo(host, port, **kwargs) if sys.version_info >= (3, 10): diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index b65ce2fdb0..2158886c80 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -17,6 +17,7 @@ import asyncio import builtins +import functools import socket import sys from typing import ( @@ -26,6 +27,7 @@ cast, ) +from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.errors import ( OperationFailure, ) @@ -70,14 +72,24 @@ def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) -def getaddrinfo(host, port, **kwargs): +def getaddrinfo( + host: Any, port: Any, **kwargs: Any +) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: if not _IS_SYNC: loop = asyncio.get_running_loop() - return loop.getaddrinfo( # type: ignore[assignment] - host, port, **kwargs + return loop.run_in_executor( # type: ignore[return-value] + _PYMONGO_EXECUTOR, functools.partial(socket.getaddrinfo, host, port, **kwargs) ) else: - return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment] + return socket.getaddrinfo(host, port, **kwargs) if sys.version_info >= (3, 10): From 85e59fd3afbbf67ea0a95bde691c0505daabf46c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 15 Jan 2025 16:53:03 -0500 Subject: [PATCH 8/9] Revert back to using the default executor --- pymongo/_asyncio_executor.py | 22 ---------------------- pymongo/asynchronous/helpers.py | 6 +----- pymongo/asynchronous/pool.py | 7 ++----- pymongo/pyopenssl_context.py | 5 ++--- pymongo/synchronous/helpers.py | 6 +----- pymongo/synchronous/pool.py | 7 ++----- 6 files changed, 8 insertions(+), 45 deletions(-) delete mode 100644 pymongo/_asyncio_executor.py diff --git a/pymongo/_asyncio_executor.py b/pymongo/_asyncio_executor.py deleted file mode 100644 index 189906a55c..0000000000 --- a/pymongo/_asyncio_executor.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A separate ThreadPoolExecutor instance used internally to avoid competing for resources with the default asyncio ThreadPoolExecutor -that user code will use.""" - -from __future__ import annotations - -from concurrent.futures import ThreadPoolExecutor - -_PYMONGO_EXECUTOR = ThreadPoolExecutor(thread_name_prefix="PYMONGO_EXECUTOR-") diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 4fba407267..e37e01bf85 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -17,7 +17,6 @@ import asyncio import builtins -import functools import socket import sys from typing import ( @@ -27,7 +26,6 @@ cast, ) -from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.errors import ( OperationFailure, ) @@ -85,9 +83,7 @@ async def getaddrinfo( ]: if not _IS_SYNC: loop = asyncio.get_running_loop() - return await loop.run_in_executor( # type: ignore[return-value] - _PYMONGO_EXECUTOR, functools.partial(socket.getaddrinfo, host, port, **kwargs) - ) + return await loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value] else: return socket.getaddrinfo(host, port, **kwargs) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 9683feff7f..7c653869d7 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -39,7 +39,6 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared -from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.asynchronous.client_session import _validate_session_write_concern from pymongo.asynchronous.helpers import _handle_reauth, getaddrinfo from pymongo.asynchronous.network import command, receive_message @@ -884,7 +883,7 @@ async def _configured_socket( else: loop = asyncio.get_running_loop() ssl_sock = await loop.run_in_executor( - _PYMONGO_EXECUTOR, + None, functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] ) else: @@ -895,9 +894,7 @@ async def _configured_socket( ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] else: loop = asyncio.get_running_loop() - ssl_sock = await loop.run_in_executor( - _PYMONGO_EXECUTOR, ssl_context.wrap_socket, sock - ) # type: ignore[assignment, misc] + ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] except _CertificateError: sock.close() # Raise _CertificateError directly like we do after match_hostname diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 038ef8df13..8c643394b2 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -31,7 +31,6 @@ from OpenSSL import SSL as _SSL from OpenSSL import crypto as _crypto -from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.errors import ConfigurationError as _ConfigurationError from pymongo.errors import _CertificateError # type:ignore[attr-defined] from pymongo.ocsp_cache import _OCSPCache @@ -406,7 +405,7 @@ async def a_wrap_socket( ssl_conn.set_tlsext_host_name(server_hostname.encode("idna")) if self.verify_mode != _stdlibssl.CERT_NONE: # Request a stapled OCSP response. - await loop.run_in_executor(_PYMONGO_EXECUTOR, ssl_conn.request_ocsp) + await loop.run_in_executor(None, ssl_conn.request_ocsp) ssl_conn.set_connect_state() # If this wasn't true the caller of wrap_socket would call # do_handshake() @@ -414,7 +413,7 @@ async def a_wrap_socket( # XXX: If we do hostname checking in a callback we can get rid # of this call to do_handshake() since the handshake # will happen automatically later. - await loop.run_in_executor(_PYMONGO_EXECUTOR, ssl_conn.do_handshake) + await loop.run_in_executor(None, ssl_conn.do_handshake) # XXX: Do this in a callback registered with # SSLContext.set_info_callback? See Twisted for an example. if self.check_hostname and server_hostname is not None: diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 2158886c80..1889f13f72 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -17,7 +17,6 @@ import asyncio import builtins -import functools import socket import sys from typing import ( @@ -27,7 +26,6 @@ cast, ) -from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.errors import ( OperationFailure, ) @@ -85,9 +83,7 @@ def getaddrinfo( ]: if not _IS_SYNC: loop = asyncio.get_running_loop() - return loop.run_in_executor( # type: ignore[return-value] - _PYMONGO_EXECUTOR, functools.partial(socket.getaddrinfo, host, port, **kwargs) - ) + return loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value] else: return socket.getaddrinfo(host, port, **kwargs) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index b4a310c903..8e788df030 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -39,7 +39,6 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared -from pymongo._asyncio_executor import _PYMONGO_EXECUTOR from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, @@ -880,7 +879,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. else: loop = asyncio.get_running_loop() ssl_sock = loop.run_in_executor( - _PYMONGO_EXECUTOR, + None, functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] ) else: @@ -891,9 +890,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] else: loop = asyncio.get_running_loop() - ssl_sock = loop.run_in_executor( - _PYMONGO_EXECUTOR, ssl_context.wrap_socket, sock - ) # type: ignore[assignment, misc] + ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] except _CertificateError: sock.close() # Raise _CertificateError directly like we do after match_hostname From abe6b24db1806e9615eff802c58327302e6d5b8c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 16 Jan 2025 08:32:31 -0500 Subject: [PATCH 9/9] getaddrinfo -> _getaddrinfo --- pymongo/asynchronous/auth.py | 4 ++-- pymongo/asynchronous/helpers.py | 2 +- pymongo/asynchronous/pool.py | 4 ++-- pymongo/synchronous/auth.py | 4 ++-- pymongo/synchronous/helpers.py | 2 +- pymongo/synchronous/pool.py | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index fbabdf66b3..b1e6d0125b 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -38,7 +38,7 @@ _authenticate_oidc, _get_authenticator, ) -from pymongo.asynchronous.helpers import getaddrinfo +from pymongo.asynchronous.helpers import _getaddrinfo from pymongo.auth_shared import ( MongoCredential, _authenticate_scram_start, @@ -185,7 +185,7 @@ async def _canonicalize_hostname(hostname: str, option: str | bool) -> str: return hostname af, socktype, proto, canonname, sockaddr = ( - await getaddrinfo( + await _getaddrinfo( hostname, None, family=0, diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index e37e01bf85..d519e8749c 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -70,7 +70,7 @@ async def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) -async def getaddrinfo( +async def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ tuple[ diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 7c653869d7..bf2f2b4946 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -40,7 +40,7 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern -from pymongo.asynchronous.helpers import _handle_reauth, getaddrinfo +from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth from pymongo.asynchronous.network import command, receive_message from pymongo.common import ( MAX_BSON_SIZE, @@ -814,7 +814,7 @@ async def _create_connection(address: _Address, options: PoolOptions) -> socket. family = socket.AF_UNSPEC err = None - for res in await getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] + for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index 016feee49f..56860eff3b 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -45,7 +45,7 @@ _authenticate_oidc, _get_authenticator, ) -from pymongo.synchronous.helpers import getaddrinfo +from pymongo.synchronous.helpers import _getaddrinfo if TYPE_CHECKING: from pymongo.hello import Hello @@ -182,7 +182,7 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str: return hostname af, socktype, proto, canonname, sockaddr = ( - getaddrinfo( + _getaddrinfo( hostname, None, family=0, diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 1889f13f72..f800e7dcc8 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -70,7 +70,7 @@ def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) -def getaddrinfo( +def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ tuple[ diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 8e788df030..05f930d480 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -84,7 +84,7 @@ from pymongo.socket_checker import SocketChecker from pymongo.ssl_support import HAS_SNI, SSLError from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.helpers import _handle_reauth, getaddrinfo +from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth from pymongo.synchronous.network import command, receive_message if TYPE_CHECKING: @@ -812,7 +812,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket family = socket.AF_UNSPEC err = None - for res in getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] + for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4