diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index c807e291c..5a2897cc5 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -180,6 +180,7 @@ class BaseSession( _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] _response_routers: list["ResponseRouter"] + _closing: bool = False def __init__( self, @@ -252,6 +253,9 @@ async def send_request( Do not use this method to emit notifications! Use send_notification() instead. """ + if self._closing: + raise McpError(ErrorData(code=CONNECTION_CLOSED, message="Connection closed")) + request_id = self._request_id self._request_id = request_id + 1 @@ -307,7 +311,8 @@ async def send_request( return result_type.model_validate(response_or_error.result) finally: - self._response_streams.pop(request_id, None) + self._response_streams.pop(request_id, None) if not self._closing else None + self._progress_callbacks.pop(request_id, None) await response_stream.aclose() await response_stream_reader.aclose() @@ -444,15 +449,17 @@ async def _receive_loop(self) -> None: finally: # after the read stream is closed, we need to send errors # to any pending requests - for id, stream in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: # pragma: no cover - # Stream might already be closed - pass - self._response_streams.clear() + self._closing = True + with anyio.CancelScope(shield=True): + for id, stream in self._response_streams.items(): + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + try: + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + except Exception: # pragma: no cover + # Stream might already be closed + pass + self._response_streams.clear() def _normalize_request_id(self, response_id: RequestId) -> RequestId: """ @@ -508,7 +515,7 @@ async def _handle_response(self, message: SessionMessage) -> None: return # Handled # Fall back to normal response streams - stream = self._response_streams.pop(response_id, None) + stream = self._response_streams.pop(response_id, None) if not self._closing else None if stream: # pragma: no cover await stream.send(root) else: # pragma: no cover diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index b355a4bf2..3142270fc 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -337,3 +337,74 @@ async def mock_server(): await ev_closed.wait() with anyio.fail_after(1): await ev_response.wait() + + +@pytest.mark.anyio +async def test_session_aexit_cleanup(): + """Test that the session is closing properly, cleaning up all resources.""" + pending_request_ids: list[int | str] = [] + requests_received = anyio.Event() + client_session_closed = anyio.Event() + + async with ( + anyio.create_task_group() as tg, + create_client_server_memory_streams() as (client_streams, server_streams), + ): + client_read, client_write = client_streams + server_read, _ = server_streams + + async def mock_server(): + """Block responses to simulate a server that does not respond.""" + # Wait for two ping requests + for _ in range(2): + message = await server_read.receive() + assert isinstance(message, SessionMessage) + root = message.message.root + assert isinstance(root, JSONRPCRequest) + assert root.method == "ping" + pending_request_ids.append(root.id) + + # Signal that both requests have been received + requests_received.set() + + # Wait for the client session to be closed + # This ensures the cleanup logic in finally block has time to run + await client_session_closed.wait() + + async def send_ping(session: ClientSession): + # Since we are closing the session, "Connection closed" McpError is expected + with pytest.raises(McpError) as e: + await session.send_ping() + assert "Connection closed" in str(e.value) + + # Start the mock server in the background + tg.start_soon(mock_server) + + # Create a session and send multiple ping requests in background + async with ClientSession(read_stream=client_read, write_stream=client_write) as session: + # Verify initial state + assert len(session._response_streams) == 0 + + # Start two ping requests in background + tg.start_soon(send_ping, session) + tg.start_soon(send_ping, session) + + # Wait for both requests to be sent and received by server + await requests_received.wait() + await anyio.sleep(0.1) # Give time for streams to be created + + # Verify we have 2 response streams + assert len(session._response_streams) == 2 + + # We close the session by escaping the async with block + client_session_closed.set() + + # Since the sesssion has been closed, "Connection closed" McpError is expected + with pytest.raises(McpError) as e: + await session.send_ping() + assert "Connection closed" in str(e.value) + + # Verify all response streams have been cleaned up + # (This happens when the async with block exits and __aexit__ is called) + assert session is not None + assert len(session._response_streams) == 0