From b2242ae38a07b61b95cc00f67c13eb823d793f66 Mon Sep 17 00:00:00 2001 From: Krishna Date: Fri, 26 Dec 2025 12:54:49 +0530 Subject: [PATCH 1/3] Update remote_a2a_agent.py --- src/google/adk/agents/remote_a2a_agent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index f2a2730e92..13127e358a 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -529,7 +529,10 @@ async def _run_async_impl( message_id=str(uuid.uuid4()), parts=message_parts, role="user", - context_id=context_id, + # Use existing context_id if available (for conversation continuity), + # otherwise use the local session ID to maintain session identity + # across local and remote agents. + context_id=context_id if context_id else ctx.session.id, ) logger.debug(build_a2a_request_log(a2a_request)) From 8db2c6c617f54d63defd345e58a0982ffd4d118d Mon Sep 17 00:00:00 2001 From: Krishna Date: Fri, 26 Dec 2025 12:55:13 +0530 Subject: [PATCH 2/3] Update test_remote_a2a_agent.py --- .../unittests/agents/test_remote_a2a_agent.py | 154 ++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 2195be83c5..bb633b49c3 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -1590,6 +1590,160 @@ async def test_run_async_impl_successful_request(self): in mock_event.custom_metadata ) + @pytest.mark.asyncio + async def test_run_async_impl_uses_session_id_when_no_context_id(self): + """Test that session ID is used as context_id when no existing context. + + When _construct_message_parts_from_session returns None for context_id, + the agent should use ctx.session.id to maintain session identity across + local and remote agents. + """ + with patch.object(self.agent, "_ensure_resolved"): + with patch.object( + self.agent, "_create_a2a_request_for_user_function_response" + ) as mock_create_func: + mock_create_func.return_value = None + + with patch.object( + self.agent, "_construct_message_parts_from_session" + ) as mock_construct: + # Create proper A2A part mocks + from a2a.client import Client as A2AClient + from a2a.types import TextPart + + mock_a2a_part = Mock(spec=TextPart) + # Return None for context_id to trigger session ID fallback + mock_construct.return_value = ( + [mock_a2a_part], + None, + ) # Tuple with parts and NO context_id + + # Mock A2A client + mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_response = Mock() + mock_send_message = AsyncMock() + mock_send_message.__aiter__.return_value = [mock_response] + mock_a2a_client.send_message.return_value = mock_send_message + self.agent._a2a_client = mock_a2a_client + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + + with patch.object(self.agent, "_handle_a2a_response") as mock_handle: + mock_handle.return_value = mock_event + + # Mock the logging functions to avoid iteration issues + with patch( + "google.adk.agents.remote_a2a_agent.build_a2a_request_log" + ) as mock_req_log: + with patch( + "google.adk.agents.remote_a2a_agent.build_a2a_response_log" + ) as mock_resp_log: + mock_req_log.return_value = "Mock request log" + mock_resp_log.return_value = "Mock response log" + + # Mock the A2AMessage constructor to capture the arguments + with patch( + "google.adk.agents.remote_a2a_agent.A2AMessage" + ) as mock_message_class: + mock_message = Mock(spec=A2AMessage) + mock_message_class.return_value = mock_message + + # Add model_dump to mock_response for metadata + mock_response.model_dump.return_value = {"test": "response"} + + # Execute + events = [] + async for event in self.agent._run_async_impl( + self.mock_context + ): + events.append(event) + + # Verify A2AMessage was called with session ID as context_id + mock_message_class.assert_called_once() + call_kwargs = mock_message_class.call_args[1] + assert call_kwargs["context_id"] == self.mock_session.id + + @pytest.mark.asyncio + async def test_run_async_impl_preserves_existing_context_id(self): + """Test that existing context_id is preserved when available. + + When _construct_message_parts_from_session returns a context_id from + a previous remote agent response, that context_id should be used + for conversation continuity. + """ + with patch.object(self.agent, "_ensure_resolved"): + with patch.object( + self.agent, "_create_a2a_request_for_user_function_response" + ) as mock_create_func: + mock_create_func.return_value = None + + with patch.object( + self.agent, "_construct_message_parts_from_session" + ) as mock_construct: + # Create proper A2A part mocks + from a2a.client import Client as A2AClient + from a2a.types import TextPart + + mock_a2a_part = Mock(spec=TextPart) + existing_context_id = "existing-context-456" + mock_construct.return_value = ( + [mock_a2a_part], + existing_context_id, + ) # Tuple with parts and existing context_id + + # Mock A2A client + mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_response = Mock() + mock_send_message = AsyncMock() + mock_send_message.__aiter__.return_value = [mock_response] + mock_a2a_client.send_message.return_value = mock_send_message + self.agent._a2a_client = mock_a2a_client + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + + with patch.object(self.agent, "_handle_a2a_response") as mock_handle: + mock_handle.return_value = mock_event + + # Mock the logging functions to avoid iteration issues + with patch( + "google.adk.agents.remote_a2a_agent.build_a2a_request_log" + ) as mock_req_log: + with patch( + "google.adk.agents.remote_a2a_agent.build_a2a_response_log" + ) as mock_resp_log: + mock_req_log.return_value = "Mock request log" + mock_resp_log.return_value = "Mock response log" + + # Mock the A2AMessage constructor to capture the arguments + with patch( + "google.adk.agents.remote_a2a_agent.A2AMessage" + ) as mock_message_class: + mock_message = Mock(spec=A2AMessage) + mock_message_class.return_value = mock_message + + # Add model_dump to mock_response for metadata + mock_response.model_dump.return_value = {"test": "response"} + + # Execute + events = [] + async for event in self.agent._run_async_impl( + self.mock_context + ): + events.append(event) + + # Verify A2AMessage was called with existing context_id + mock_message_class.assert_called_once() + call_kwargs = mock_message_class.call_args[1] + assert call_kwargs["context_id"] == existing_context_id + @pytest.mark.asyncio async def test_run_async_impl_a2a_client_error(self): """Test _run_async_impl when A2A send_message fails.""" From 5c1106d79883637884979ac6b32cf96734f1cefc Mon Sep 17 00:00:00 2001 From: Krishna Date: Fri, 26 Dec 2025 13:06:47 +0530 Subject: [PATCH 3/3] Update test_remote_a2a_agent.py --- .../unittests/agents/test_remote_a2a_agent.py | 121 +++++------------- 1 file changed, 32 insertions(+), 89 deletions(-) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index bb633b49c3..8f29481351 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -1590,14 +1590,19 @@ async def test_run_async_impl_successful_request(self): in mock_event.custom_metadata ) - @pytest.mark.asyncio - async def test_run_async_impl_uses_session_id_when_no_context_id(self): - """Test that session ID is used as context_id when no existing context. + async def _run_context_id_test( + self, mock_context_id: str | None, expected_context_id: str + ): + """Helper to test context_id handling in _run_async_impl. - When _construct_message_parts_from_session returns None for context_id, - the agent should use ctx.session.id to maintain session identity across - local and remote agents. + Args: + mock_context_id: The context_id to return from + _construct_message_parts_from_session. + expected_context_id: The expected context_id in the A2AMessage. """ + from a2a.client import Client as A2AClient + from a2a.types import TextPart + with patch.object(self.agent, "_ensure_resolved"): with patch.object( self.agent, "_create_a2a_request_for_user_function_response" @@ -1607,16 +1612,8 @@ async def test_run_async_impl_uses_session_id_when_no_context_id(self): with patch.object( self.agent, "_construct_message_parts_from_session" ) as mock_construct: - # Create proper A2A part mocks - from a2a.client import Client as A2AClient - from a2a.types import TextPart - mock_a2a_part = Mock(spec=TextPart) - # Return None for context_id to trigger session ID fallback - mock_construct.return_value = ( - [mock_a2a_part], - None, - ) # Tuple with parts and NO context_id + mock_construct.return_value = ([mock_a2a_part], mock_context_id) # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) @@ -1635,7 +1632,6 @@ async def test_run_async_impl_uses_session_id_when_no_context_id(self): with patch.object(self.agent, "_handle_a2a_response") as mock_handle: mock_handle.return_value = mock_event - # Mock the logging functions to avoid iteration issues with patch( "google.adk.agents.remote_a2a_agent.build_a2a_request_log" ) as mock_req_log: @@ -1645,14 +1641,11 @@ async def test_run_async_impl_uses_session_id_when_no_context_id(self): mock_req_log.return_value = "Mock request log" mock_resp_log.return_value = "Mock response log" - # Mock the A2AMessage constructor to capture the arguments with patch( "google.adk.agents.remote_a2a_agent.A2AMessage" ) as mock_message_class: mock_message = Mock(spec=A2AMessage) mock_message_class.return_value = mock_message - - # Add model_dump to mock_response for metadata mock_response.model_dump.return_value = {"test": "response"} # Execute @@ -1662,10 +1655,23 @@ async def test_run_async_impl_uses_session_id_when_no_context_id(self): ): events.append(event) - # Verify A2AMessage was called with session ID as context_id + # Verify A2AMessage was called with expected context_id mock_message_class.assert_called_once() call_kwargs = mock_message_class.call_args[1] - assert call_kwargs["context_id"] == self.mock_session.id + assert call_kwargs["context_id"] == expected_context_id + + @pytest.mark.asyncio + async def test_run_async_impl_uses_session_id_when_no_context_id(self): + """Test that session ID is used as context_id when no existing context. + + When _construct_message_parts_from_session returns None for context_id, + the agent should use ctx.session.id to maintain session identity across + local and remote agents. + """ + await self._run_context_id_test( + mock_context_id=None, + expected_context_id=self.mock_session.id, + ) @pytest.mark.asyncio async def test_run_async_impl_preserves_existing_context_id(self): @@ -1675,74 +1681,11 @@ async def test_run_async_impl_preserves_existing_context_id(self): a previous remote agent response, that context_id should be used for conversation continuity. """ - with patch.object(self.agent, "_ensure_resolved"): - with patch.object( - self.agent, "_create_a2a_request_for_user_function_response" - ) as mock_create_func: - mock_create_func.return_value = None - - with patch.object( - self.agent, "_construct_message_parts_from_session" - ) as mock_construct: - # Create proper A2A part mocks - from a2a.client import Client as A2AClient - from a2a.types import TextPart - - mock_a2a_part = Mock(spec=TextPart) - existing_context_id = "existing-context-456" - mock_construct.return_value = ( - [mock_a2a_part], - existing_context_id, - ) # Tuple with parts and existing context_id - - # Mock A2A client - mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock() - mock_send_message = AsyncMock() - mock_send_message.__aiter__.return_value = [mock_response] - mock_a2a_client.send_message.return_value = mock_send_message - self.agent._a2a_client = mock_a2a_client - - mock_event = Event( - author=self.agent.name, - invocation_id=self.mock_context.invocation_id, - branch=self.mock_context.branch, - ) - - with patch.object(self.agent, "_handle_a2a_response") as mock_handle: - mock_handle.return_value = mock_event - - # Mock the logging functions to avoid iteration issues - with patch( - "google.adk.agents.remote_a2a_agent.build_a2a_request_log" - ) as mock_req_log: - with patch( - "google.adk.agents.remote_a2a_agent.build_a2a_response_log" - ) as mock_resp_log: - mock_req_log.return_value = "Mock request log" - mock_resp_log.return_value = "Mock response log" - - # Mock the A2AMessage constructor to capture the arguments - with patch( - "google.adk.agents.remote_a2a_agent.A2AMessage" - ) as mock_message_class: - mock_message = Mock(spec=A2AMessage) - mock_message_class.return_value = mock_message - - # Add model_dump to mock_response for metadata - mock_response.model_dump.return_value = {"test": "response"} - - # Execute - events = [] - async for event in self.agent._run_async_impl( - self.mock_context - ): - events.append(event) - - # Verify A2AMessage was called with existing context_id - mock_message_class.assert_called_once() - call_kwargs = mock_message_class.call_args[1] - assert call_kwargs["context_id"] == existing_context_id + existing_context_id = "existing-context-456" + await self._run_context_id_test( + mock_context_id=existing_context_id, + expected_context_id=existing_context_id, + ) @pytest.mark.asyncio async def test_run_async_impl_a2a_client_error(self):