diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index f2a2730e92..13127e358a 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -529,7 +529,10 @@ async def _run_async_impl( message_id=str(uuid.uuid4()), parts=message_parts, role="user", - context_id=context_id, + # Use existing context_id if available (for conversation continuity), + # otherwise use the local session ID to maintain session identity + # across local and remote agents. + context_id=context_id if context_id else ctx.session.id, ) logger.debug(build_a2a_request_log(a2a_request)) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 2195be83c5..8f29481351 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -1590,6 +1590,103 @@ async def test_run_async_impl_successful_request(self): in mock_event.custom_metadata ) + async def _run_context_id_test( + self, mock_context_id: str | None, expected_context_id: str + ): + """Helper to test context_id handling in _run_async_impl. + + Args: + mock_context_id: The context_id to return from + _construct_message_parts_from_session. + expected_context_id: The expected context_id in the A2AMessage. + """ + from a2a.client import Client as A2AClient + from a2a.types import TextPart + + with patch.object(self.agent, "_ensure_resolved"): + with patch.object( + self.agent, "_create_a2a_request_for_user_function_response" + ) as mock_create_func: + mock_create_func.return_value = None + + with patch.object( + self.agent, "_construct_message_parts_from_session" + ) as mock_construct: + mock_a2a_part = Mock(spec=TextPart) + mock_construct.return_value = ([mock_a2a_part], mock_context_id) + + # Mock A2A client + mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_response = Mock() + mock_send_message = AsyncMock() + mock_send_message.__aiter__.return_value = [mock_response] + mock_a2a_client.send_message.return_value = mock_send_message + self.agent._a2a_client = mock_a2a_client + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + + with patch.object(self.agent, "_handle_a2a_response") as mock_handle: + mock_handle.return_value = mock_event + + with patch( + "google.adk.agents.remote_a2a_agent.build_a2a_request_log" + ) as mock_req_log: + with patch( + "google.adk.agents.remote_a2a_agent.build_a2a_response_log" + ) as mock_resp_log: + mock_req_log.return_value = "Mock request log" + mock_resp_log.return_value = "Mock response log" + + with patch( + "google.adk.agents.remote_a2a_agent.A2AMessage" + ) as mock_message_class: + mock_message = Mock(spec=A2AMessage) + mock_message_class.return_value = mock_message + mock_response.model_dump.return_value = {"test": "response"} + + # Execute + events = [] + async for event in self.agent._run_async_impl( + self.mock_context + ): + events.append(event) + + # Verify A2AMessage was called with expected context_id + mock_message_class.assert_called_once() + call_kwargs = mock_message_class.call_args[1] + assert call_kwargs["context_id"] == expected_context_id + + @pytest.mark.asyncio + async def test_run_async_impl_uses_session_id_when_no_context_id(self): + """Test that session ID is used as context_id when no existing context. + + When _construct_message_parts_from_session returns None for context_id, + the agent should use ctx.session.id to maintain session identity across + local and remote agents. + """ + await self._run_context_id_test( + mock_context_id=None, + expected_context_id=self.mock_session.id, + ) + + @pytest.mark.asyncio + async def test_run_async_impl_preserves_existing_context_id(self): + """Test that existing context_id is preserved when available. + + When _construct_message_parts_from_session returns a context_id from + a previous remote agent response, that context_id should be used + for conversation continuity. + """ + existing_context_id = "existing-context-456" + await self._run_context_id_test( + mock_context_id=existing_context_id, + expected_context_id=existing_context_id, + ) + @pytest.mark.asyncio async def test_run_async_impl_a2a_client_error(self): """Test _run_async_impl when A2A send_message fails."""