diff --git a/src/google/adk/flows/llm_flows/_code_execution.py b/src/google/adk/flows/llm_flows/_code_execution.py
index bfa84db69d..521ee83868 100644
--- a/src/google/adk/flows/llm_flows/_code_execution.py
+++ b/src/google/adk/flows/llm_flows/_code_execution.py
@@ -50,6 +50,8 @@
 
 logger = logging.getLogger('google_adk.' + __name__)
 
+_AVAILABLE_FILE_PREFIX = 'Available file:'
+
 
 @dataclasses.dataclass
 class DataFileUtil:
@@ -206,7 +208,7 @@ async def _run_pre_processor(
   # memory. Meanwhile, mutate the inline data file to text part in session
   # history from all turns.
   all_input_files = _extract_and_replace_inline_files(
-      code_executor_context, llm_request
+      code_executor_context, llm_request, invocation_context
   )
 
   # [Step 2] Run Explore_Df code on the data files from the current turn. We
@@ -375,20 +377,40 @@ async def _run_post_processor(
 def _extract_and_replace_inline_files(
     code_executor_context: CodeExecutorContext,
     llm_request: LlmRequest,
+    invocation_context: InvocationContext,
 ) -> list[File]:
-  """Extracts and replaces inline files with file names in the LLM request."""
+  """Extracts and replaces inline files with file names in the LLM request.
+
+  This function modifies both `llm_request.contents` for the current request
+  and `invocation_context.session.events` to ensure the replacement of inline
+  data with placeholders persists across conversation turns.
+
+  Args:
+    code_executor_context: Context containing code executor state.
+    llm_request: The LLM request to process.
+    invocation_context: Context containing session information.
+
+  Returns:
+    All input files cached in the code executor context, old and new.
+  """
   all_input_files = code_executor_context.get_input_files()
   saved_file_names = set(f.name for f in all_input_files)
+  session = invocation_context.session
 
-  # [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
+  # Maps session event index -> {part index -> placeholder text}. Updates are
+  # collected during the scan and applied afterwards so that each event is
+  # copied at most once.
+  events_to_update: dict[int, dict[int, str]] = {}
+
+  # Process input files from LlmRequest and cache them in CodeExecutor.
   for i in range(len(llm_request.contents)):
     content = llm_request.contents[i]
     # Only process the user message.
-    if content.role != 'user' and not content.parts:
+    if content.role != 'user' or not content.parts:
       continue
 
     for j in range(len(content.parts)):
       part = content.parts[j]
       # Skip if the inline data is not supported.
       if (
           not part.inline_data
@@ -399,21 +421,74 @@ def _extract_and_replace_inline_files(
       # Replace the inline data file with a file name placeholder.
       mime_type = part.inline_data.mime_type
       file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension
-      llm_request.contents[i].parts[j] = types.Part(
-          text='\nAvailable file: `%s`\n' % file_name
-      )
+      placeholder_text = f'\n{_AVAILABLE_FILE_PREFIX} `{file_name}`\n'
+
+      # Keep a handle on the blob: the part holding it is replaced below.
+      inline_data = part.inline_data
+
+      # Replace in llm_request
+      llm_request.contents[i].parts[j] = types.Part(text=placeholder_text)
+
+      # Find the session event this part came from so that the replacement
+      # persists across turns. A slot already claimed by an earlier content is
+      # skipped: updates are deferred, so without this two uploads of the same
+      # file at the same part index would both resolve to the first event and
+      # the later event would keep its inline data.
+      for event_idx, event in enumerate(session.events):
+        event_content = event.content
+        if (
+            not event_content
+            or event_content.role != 'user'
+            or not event_content.parts
+            or len(event_content.parts) <= j
+            or j in events_to_update.get(event_idx, {})
+        ):
+          continue
+        event_blob = event_content.parts[j].inline_data
+        # bytes equality already rejects different lengths up front, so no
+        # separate length check is needed.
+        if (
+            event_blob
+            and event_blob.mime_type == mime_type
+            and event_blob.data == inline_data.data
+        ):
+          events_to_update.setdefault(event_idx, {})[j] = placeholder_text
+          break
 
       # Add the inline data as input file to the code executor context.
       file = File(
           name=file_name,
           content=CodeExecutionUtils.get_encoded_file_content(
-              part.inline_data.data
+              inline_data.data
           ).decode(),
           mime_type=mime_type,
       )
       if file_name not in saved_file_names:
         code_executor_context.add_input_files([file])
         all_input_files.append(file)
+        saved_file_names.add(file_name)
+
+  # Apply the collected replacements to session.events so they persist.
+  for event_idx, parts_to_update in events_to_update.items():
+    event = session.events[event_idx]
+    updated_parts = list(event.content.parts)
+    for part_idx, text in parts_to_update.items():
+      updated_parts[part_idx] = types.Part(text=text)
+
+    # Event is a Pydantic model: swap in a copy carrying the new content
+    # rather than mutating the existing event's parts list.
+    session.events[event_idx] = event.model_copy(
+        update={
+            'content': types.Content(
+                role=event.content.role, parts=updated_parts
+            )
+        }
+    )
+    logger.debug(
+        'Replaced inline_data in session.events[%d] at part indices %s',
+        event_idx,
+        sorted(parts_to_update),
+    )
 
   return all_input_files
 
diff --git a/tests/unittests/flows/llm_flows/test_code_execution_persistence.py b/tests/unittests/flows/llm_flows/test_code_execution_persistence.py
new file mode 100644
index 0000000000..4523fd6114
--- /dev/null
+++ b/tests/unittests/flows/llm_flows/test_code_execution_persistence.py
@@ -0,0 +1,171 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for optimize_data_file persistence across turns."""
+
+import copy
+from unittest import mock
+
+from google.adk.code_executors.code_executor_context import CodeExecutorContext
+from google.adk.events.event import Event
+from google.adk.flows.llm_flows._code_execution import _extract_and_replace_inline_files
+from google.adk.models.llm_request import LlmRequest
+from google.adk.sessions.session import Session
+from google.genai import types
+import pytest
+
+_CSV_DATA = b'name,value\nA,100\nB,200'
+_PLACEHOLDER_PREFIX = 'Available file:'
+
+
+def _make_session() -> Session:
+  """Returns an empty session."""
+  return Session(
+      id='test_session',
+      app_name='test_app',
+      user_id='test_user',
+      state={},
+      events=[],
+  )
+
+
+def _csv_content(text: str) -> types.Content:
+  """Returns a user message: a text part followed by an inline CSV blob."""
+  return types.Content(
+      role='user',
+      parts=[
+          types.Part(text=text),
+          types.Part(
+              inline_data=types.Blob(mime_type='text/csv', data=_CSV_DATA)
+          ),
+      ],
+  )
+
+
+def _add_user_event(
+    session: Session, invocation_id: str, content: types.Content
+) -> None:
+  """Appends `content` to `session` as a user-authored event."""
+  session.events.append(
+      Event(invocation_id=invocation_id, author='user', content=content)
+  )
+
+
+def _make_invocation_context(session: Session) -> mock.Mock:
+  """Returns a stand-in invocation context that only exposes `session`."""
+  invocation_context = mock.Mock()
+  invocation_context.session = session
+  return invocation_context
+
+
+def _request_from_session(session: Session) -> LlmRequest:
+  """Builds a request from deep copies of every event's content."""
+  return LlmRequest(
+      contents=[copy.deepcopy(e.content) for e in session.events if e.content]
+  )
+
+
+def _session_contents(session: Session) -> list[types.Content]:
+  """Returns the live (not copied) content of every session event."""
+  return [e.content for e in session.events if e.content]
+
+
+def _has_inline_data(contents: list[types.Content]) -> bool:
+  """Returns True if any part of any content still carries inline data."""
+  return any(p.inline_data for c in contents for p in c.parts)
+
+
+def _placeholders(contents: list[types.Content]) -> list[str]:
+  """Returns the text of every file-placeholder part, in order."""
+  return [
+      p.text
+      for c in contents
+      for p in c.parts
+      if _PLACEHOLDER_PREFIX in (p.text or '')
+  ]
+
+
+def test_inline_data_replacement_in_extract_function():
+  """The request and the backing session event are both rewritten."""
+  session = _make_session()
+  _add_user_event(session, 'test_inv', _csv_content('Process this'))
+  llm_request = _request_from_session(session)
+
+  result_files = _extract_and_replace_inline_files(
+      CodeExecutorContext(session.state),
+      llm_request,
+      _make_invocation_context(session),
+  )
+
+  assert not _has_inline_data(llm_request.contents)
+  assert len(_placeholders(llm_request.contents)) == 1
+  assert not _has_inline_data(_session_contents(session))
+  assert len(_placeholders(_session_contents(session))) == 1
+  assert len(result_files) == 1
+
+
+def test_persistence_across_simulated_turns():
+  """Inline data replaced in turn 1 does not reappear in turn 2."""
+  session = _make_session()
+  code_executor_context = CodeExecutorContext(session.state)
+  invocation_context = _make_invocation_context(session)
+
+  # Turn 1: the user uploads a CSV.
+  _add_user_event(session, 'test_inv_1', _csv_content('Process this'))
+  files_1 = _extract_and_replace_inline_files(
+      code_executor_context, _request_from_session(session), invocation_context
+  )
+  assert len(files_1) == 1
+  assert not _has_inline_data(_session_contents(session))
+
+  # Turn 2: a text-only follow-up. The request is rebuilt from the session,
+  # so it only stays free of inline data if the event itself was rewritten.
+  _add_user_event(
+      session,
+      'test_inv_2',
+      types.Content(role='user', parts=[types.Part(text='What is the sum?')]),
+  )
+  llm_request_2 = _request_from_session(session)
+  assert not _has_inline_data(llm_request_2.contents)
+
+  files_2 = _extract_and_replace_inline_files(
+      code_executor_context, llm_request_2, invocation_context
+  )
+  assert [f.name for f in files_2] == [f.name for f in files_1]
+  assert len(code_executor_context.get_input_files()) == 1
+
+
+def test_same_file_uploaded_twice_in_one_request():
+  """Identical uploads in different events are each replaced in their event."""
+  session = _make_session()
+  _add_user_event(session, 'test_inv_1', _csv_content('First copy'))
+  _add_user_event(session, 'test_inv_2', _csv_content('Second copy'))
+  llm_request = _request_from_session(session)
+
+  files = _extract_and_replace_inline_files(
+      CodeExecutorContext(session.state),
+      llm_request,
+      _make_invocation_context(session),
+  )
+
+  assert len(files) == 2
+  assert not _has_inline_data(_session_contents(session))
+  assert len(_placeholders(llm_request.contents)) == 2
+  assert _placeholders(_session_contents(session)) == _placeholders(
+      llm_request.contents
+  )
+
+
+if __name__ == '__main__':
+  pytest.main([__file__, '-v', '-s'])