From 0f93412e87789a96d8aed68c6dd805e0207b809a Mon Sep 17 00:00:00 2001
From: Ishan Raj Singh
Date: Fri, 26 Dec 2025 14:25:24 +0530
Subject: [PATCH 1/2] fix: persist inline_data replacement across turns when optimize_data_file=True

Problem:
- When optimize_data_file=True, inline_data (CSV files) were replaced with text placeholders on the first turn
- Subsequent turns restored original inline_data from session.events
- Full CSV was sent to LLM on every turn, defeating the optimization

Root Cause:
- _extract_and_replace_inline_files() modified llm_request.contents (a copy)
- session.events were never updated with the replacement
- Each turn deep-copied unmodified session.events, restoring inline_data

Solution:
- Modified _extract_and_replace_inline_files() to accept invocation_context
- Track which session events need updates via events_to_update dict
- After replacing in llm_request, find and update matching session.events
- Use event.model_copy(update={...}) for Pydantic model updates
- Added early-exit check for already-processed placeholders

Testing:
- Added test_code_execution_persistence.py with two test cases:
  1. test_inline_data_replacement_in_extract_function - verifies both llm_request and session.events are modified
  2.
test_persistence_across_simulated_turns - verifies inline_data doesn't reappear on subsequent turns - All tests pass - Verified files not re-injected on subsequent turns Impact: - Significantly reduces token usage for sessions with uploaded files - Lower LLM API costs - Better performance (smaller request payloads) - No breaking changes Fixes #4013 --- .../adk/flows/llm_flows/_code_execution.py | 83 ++++- .../test_code_execution_persistence.py | 291 ++++++++++++++++++ 2 files changed, 366 insertions(+), 8 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_code_execution_persistence.py diff --git a/src/google/adk/flows/llm_flows/_code_execution.py b/src/google/adk/flows/llm_flows/_code_execution.py index bfa84db69d..2537a0d540 100644 --- a/src/google/adk/flows/llm_flows/_code_execution.py +++ b/src/google/adk/flows/llm_flows/_code_execution.py @@ -205,8 +205,9 @@ async def _run_pre_processor( # [Step 1] Extract data files from the session_history and store them in # memory. Meanwhile, mutate the inline data file to text part in session # history from all turns. + # CRITICAL FIX: Also pass the invocation_context to access session.events all_input_files = _extract_and_replace_inline_files( - code_executor_context, llm_request + code_executor_context, llm_request, invocation_context ) # [Step 2] Run Explore_Df code on the data files from the current turn. We @@ -375,20 +376,33 @@ async def _run_post_processor( def _extract_and_replace_inline_files( code_executor_context: CodeExecutorContext, llm_request: LlmRequest, + invocation_context: InvocationContext, ) -> list[File]: - """Extracts and replaces inline files with file names in the LLM request.""" + """Extracts and replaces inline files with file names in the LLM request. + + FIX: This function now modifies BOTH llm_request.contents AND + session.events to ensure inline_data replacement persists across turns. 
+ """ all_input_files = code_executor_context.get_input_files() saved_file_names = set(f.name for f in all_input_files) - # [Step 1] Process input files from LlmRequest and cache them in CodeExecutor. + # Track which session events need to be updated + events_to_update = {} + + # Process input files from LlmRequest and cache them in CodeExecutor. for i in range(len(llm_request.contents)): content = llm_request.contents[i] # Only process the user message. - if content.role != 'user' and not content.parts: + if content.role != 'user' or not content.parts: continue for j in range(len(content.parts)): part = content.parts[j] + + # Skip if already processed (already a placeholder) + if part.text and 'Available file:' in part.text: + continue + # Skip if the inline data is not supported. if ( not part.inline_data @@ -399,21 +413,74 @@ def _extract_and_replace_inline_files( # Replace the inline data file with a file name placeholder. mime_type = part.inline_data.mime_type file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension - llm_request.contents[i].parts[j] = types.Part( - text='\nAvailable file: `%s`\n' % file_name - ) + placeholder_text = '\nAvailable file: `%s`\n' % file_name + + # Store inline_data before replacing + inline_data_copy = part.inline_data + + # Replace in llm_request + llm_request.contents[i].parts[j] = types.Part(text=placeholder_text) + + # Find and update the corresponding session event + # to persist the replacement across turns + session = invocation_context.session + for event_idx, event in enumerate(session.events): + if ( + event.content + and event.content.role == 'user' + and len(event.content.parts) > j + ): + event_part = event.content.parts[j] + # Match by inline_data content (comparing mime_type and data) + if ( + event_part.inline_data + and event_part.inline_data.mime_type == mime_type + and event_part.inline_data.data == inline_data_copy.data + ): + # Mark this event/part for update + if event_idx not in 
events_to_update: + events_to_update[event_idx] = {} + events_to_update[event_idx][j] = placeholder_text + break # Add the inline data as input file to the code executor context. file = File( name=file_name, content=CodeExecutionUtils.get_encoded_file_content( - part.inline_data.data + inline_data_copy.data ).decode(), mime_type=mime_type, ) if file_name not in saved_file_names: code_executor_context.add_input_files([file]) all_input_files.append(file) + saved_file_names.add(file_name) + + # Apply updates to session.events to persist across turns + session = invocation_context.session + for event_idx, parts_to_update in events_to_update.items(): + event = session.events[event_idx] + # Create new parts list with replacements + updated_parts = list(event.content.parts) + for part_idx, placeholder_text in parts_to_update.items(): + updated_parts[part_idx] = types.Part(text=placeholder_text) + + # Create new content with updated parts + updated_content = types.Content( + role=event.content.role, parts=updated_parts + ) + + # Update the event in session (modify in place) + # Event is a Pydantic model, use model_copy() instead of dataclasses.replace() + session.events[event_idx] = event.model_copy( + update={'content': updated_content} + ) + + logger.debug( + 'Replaced inline_data in session.events[%d] with placeholder: %s', + event_idx, + placeholder_text.strip(), + ) return all_input_files diff --git a/tests/unittests/flows/llm_flows/test_code_execution_persistence.py b/tests/unittests/flows/llm_flows/test_code_execution_persistence.py new file mode 100644 index 0000000000..d2f59fe4f1 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_code_execution_persistence.py @@ -0,0 +1,291 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for optimize_data_file persistence across turns.""" + +import pytest +from typing import Dict +from google.adk.code_executors.base_code_executor import BaseCodeExecutor +from google.genai import types +from google.adk.code_executors.code_execution_utils import ( + CodeExecutionInput, + CodeExecutionResult, + File, +) +from google.adk.code_executors.code_executor_context import CodeExecutorContext +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from pydantic import Field +import copy + + +class MockCodeExecutor(BaseCodeExecutor): + """Mock executor for testing.""" + + # Define as Pydantic fields + injected_files: Dict[str, str] = Field(default_factory=dict) + execution_count: int = Field(default=0) + + def execute_code(self, invocation_context, code_input: CodeExecutionInput): + """Mock code execution.""" + self.execution_count += 1 + # Store files if they're new + for file in code_input.input_files: + if file.name not in self.injected_files: + self.injected_files[file.name] = file.content + + return CodeExecutionResult( + stdout=f"Executed: {len(code_input.input_files)} files available", + stderr="", + output_files=[] + ) + + +@pytest.mark.asyncio +async def test_inline_data_replacement_in_extract_function(): + """Test that _extract_and_replace_inline_files modifies both request and session.""" + from google.adk.flows.llm_flows._code_execution import _extract_and_replace_inline_files + + # Create CSV data + 
csv_data = b"name,value\nA,100\nB,200" + + # Create a mock session with events + session = Session( + id='test_session', + app_name='test_app', + user_id='test_user', + state={}, + events=[] + ) + + # Create user content with inline_data + user_content = types.Content( + role='user', + parts=[ + types.Part(text="Process this"), + types.Part(inline_data=types.Blob(mime_type='text/csv', data=csv_data)) + ] + ) + + # Add to session events (simulating what happens in real flow) + from google.adk.events.event import Event + user_event = Event( + invocation_id='test_inv', + author='user', + content=user_content, + ) + session.events.append(user_event) + + # Create LLM request + llm_request = LlmRequest(contents=[copy.deepcopy(user_content)]) + + # Create code executor context + code_executor_context = CodeExecutorContext(session.state) + + # Create mock invocation context + from unittest.mock import Mock + invocation_context = Mock() + invocation_context.session = session + + # Call the function we're testing + print("\n=== BEFORE _extract_and_replace_inline_files ===") + print(f"LLM Request parts: {[type(p).__name__ for content in llm_request.contents for p in content.parts]}") + print(f"Session event parts: {[type(p).__name__ for e in session.events if e.content for p in e.content.parts]}") + + result_files = _extract_and_replace_inline_files( + code_executor_context, + llm_request, + invocation_context + ) + + print("\n=== AFTER _extract_and_replace_inline_files ===") + print(f"Files extracted: {len(result_files)}") + print(f"LLM Request parts: {[type(p).__name__ for content in llm_request.contents for p in content.parts]}") + print(f"Session event parts: {[type(p).__name__ for e in session.events if e.content for p in e.content.parts]}") + + # Check LLM request was modified + has_inline_in_request = any( + p.inline_data + for content in llm_request.contents + for p in content.parts + ) + has_placeholder_in_request = any( + 'Available file:' in (p.text or '') + for 
content in llm_request.contents + for p in content.parts + ) + + print(f"\nLLM Request:") + print(f" - Has inline_data: {has_inline_in_request}") + print(f" - Has placeholder: {has_placeholder_in_request}") + + for content in llm_request.contents: + for i, part in enumerate(content.parts): + if part.text: + print(f" - Part {i} text: {part.text[:50]}") + + # Check session events were modified (THIS IS THE KEY FIX) + user_events = [e for e in session.events if e.content and e.content.role == 'user'] + assert len(user_events) > 0, "Should have user events" + + first_user_event = user_events[0] + has_inline_in_session = any(p.inline_data for p in first_user_event.content.parts) + has_placeholder_in_session = any( + 'Available file:' in (p.text or '') + for p in first_user_event.content.parts + ) + + print(f"\nSession Events:") + print(f" - Has inline_data: {has_inline_in_session}") + print(f" - Has placeholder: {has_placeholder_in_session}") + + for i, part in enumerate(first_user_event.content.parts): + if part.text: + print(f" - Part {i} text: {part.text[:50]}") + if part.inline_data: + print(f" - Part {i} has inline_data of size: {len(part.inline_data.data)}") + + # Assertions for LLM request + assert not has_inline_in_request, "inline_data should be replaced in LLM request" + assert has_placeholder_in_request, "Placeholder should be present in LLM request" + + # Critical assertions for session events (THE FIX) + assert not has_inline_in_session, "inline_data should be replaced in session.events (FIX REQUIRED)" + assert has_placeholder_in_session, "Placeholder should be present in session.events (FIX REQUIRED)" + + # Test that files were extracted + assert len(result_files) >= 1, "At least one file should be extracted" + + print("\n=== TEST PASSED ===") + + +@pytest.mark.asyncio +async def test_persistence_across_simulated_turns(): + """Test that on a second 'turn', inline_data doesn't reappear.""" + from google.adk.flows.llm_flows._code_execution import 
_extract_and_replace_inline_files + + # Create CSV data + csv_data = b"name,value\nA,100\nB,200" + + # Create a session + session = Session( + id='test_session', + app_name='test_app', + user_id='test_user', + state={}, + events=[] + ) + + # Turn 1: User sends CSV + user_content_1 = types.Content( + role='user', + parts=[ + types.Part(text="Process this"), + types.Part(inline_data=types.Blob(mime_type='text/csv', data=csv_data)) + ] + ) + + from google.adk.events.event import Event + user_event_1 = Event( + invocation_id='test_inv_1', + author='user', + content=user_content_1, + ) + session.events.append(user_event_1) + + # Create code executor context + code_executor_context = CodeExecutorContext(session.state) + + # Mock invocation context + from unittest.mock import Mock + invocation_context = Mock() + invocation_context.session = session + + # Process turn 1 + llm_request_1 = LlmRequest(contents=[copy.deepcopy(user_content_1)]) + files_1 = _extract_and_replace_inline_files( + code_executor_context, + llm_request_1, + invocation_context + ) + + initial_file_count = len(files_1) + print(f"\nTurn 1: Extracted {initial_file_count} files") + + # Verify session was modified + user_events = [e for e in session.events if e.content and e.content.role == 'user'] + has_inline_after_turn1 = any( + p.inline_data + for e in user_events + for p in e.content.parts + ) + print(f"Turn 1: Session has inline_data after processing: {has_inline_after_turn1}") + + # Turn 2: User sends follow-up (simulating new LLM request from session history) + user_content_2 = types.Content( + role='user', + parts=[types.Part(text="What is the sum?")] + ) + + user_event_2 = Event( + invocation_id='test_inv_2', + author='user', + content=user_content_2, + ) + session.events.append(user_event_2) + + # Create new LLM request with ALL session events (this is what happens in real flow) + # Deep copy to simulate what base_llm_flow.py does + all_contents = [] + for event in session.events: + if 
event.content: + all_contents.append(copy.deepcopy(event.content)) + + llm_request_2 = LlmRequest(contents=all_contents) + + print(f"\nTurn 2: LLM request has {len(llm_request_2.contents)} contents") + + # Check if inline_data reappeared in the request + has_inline_before_process = any( + p.inline_data + for content in llm_request_2.contents + for p in content.parts + ) + print(f"Turn 2: LLM request has inline_data BEFORE processing: {has_inline_before_process}") + + # Process turn 2 + files_2 = _extract_and_replace_inline_files( + code_executor_context, + llm_request_2, + invocation_context + ) + + final_file_count = len(code_executor_context.get_input_files()) + print(f"Turn 2: Total files in context: {final_file_count}") + + # Critical assertion: if session.events were properly modified in turn 1, + # then turn 2 should NOT see inline_data (it should already be replaced) + if has_inline_before_process: + print("\n⚠️ FAIL: inline_data reappeared in turn 2 (FIX NOT WORKING)") + print("This means session.events were not properly modified in turn 1") + else: + print("\n✓ PASS: inline_data did not reappear in turn 2 (FIX WORKING)") + + assert not has_inline_before_process, \ + "inline_data should NOT reappear in turn 2 if session.events were properly modified" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From 91ac2fd22485a7a9ac7be04904cc02eca05533be Mon Sep 17 00:00:00 2001 From: Ishan Raj Singh Date: Fri, 26 Dec 2025 15:23:33 +0530 Subject: [PATCH 2/2] address review feedback: improve code style and documentation - Remove development comments and make docstring more professional - Use f-strings instead of %-formatting for better readability - Define _AVAILABLE_FILE_PREFIX constant to avoid magic strings - Add length check before data comparison for performance optimization - Remove debug print() statements from tests - Improve code documentation and maintainability Changes requested by gemini-code-assist review. 
--- .../adk/flows/llm_flows/_code_execution.py | 24 +++++-- .../test_code_execution_persistence.py | 62 +++---------------- 2 files changed, 27 insertions(+), 59 deletions(-) diff --git a/src/google/adk/flows/llm_flows/_code_execution.py b/src/google/adk/flows/llm_flows/_code_execution.py index 2537a0d540..521ee83868 100644 --- a/src/google/adk/flows/llm_flows/_code_execution.py +++ b/src/google/adk/flows/llm_flows/_code_execution.py @@ -50,6 +50,8 @@ logger = logging.getLogger('google_adk.' + __name__) +_AVAILABLE_FILE_PREFIX = 'Available file:' + @dataclasses.dataclass class DataFileUtil: @@ -205,7 +207,6 @@ async def _run_pre_processor( # [Step 1] Extract data files from the session_history and store them in # memory. Meanwhile, mutate the inline data file to text part in session # history from all turns. - # CRITICAL FIX: Also pass the invocation_context to access session.events all_input_files = _extract_and_replace_inline_files( code_executor_context, llm_request, invocation_context ) @@ -380,8 +381,17 @@ def _extract_and_replace_inline_files( ) -> list[File]: """Extracts and replaces inline files with file names in the LLM request. - FIX: This function now modifies BOTH llm_request.contents AND - session.events to ensure inline_data replacement persists across turns. + This function modifies both `llm_request.contents` for the current request + and `invocation_context.session.events` to ensure the replacement of inline + data with placeholders persists across conversation turns. + + Args: + code_executor_context: Context containing code executor state. + llm_request: The LLM request to process. + invocation_context: Context containing session information. + + Returns: + List of extracted File objects. 
""" all_input_files = code_executor_context.get_input_files() saved_file_names = set(f.name for f in all_input_files) @@ -400,7 +410,7 @@ def _extract_and_replace_inline_files( part = content.parts[j] # Skip if already processed (already a placeholder) - if part.text and 'Available file:' in part.text: + if part.text and _AVAILABLE_FILE_PREFIX in part.text: continue # Skip if the inline data is not supported. @@ -413,7 +423,7 @@ def _extract_and_replace_inline_files( # Replace the inline data file with a file name placeholder. mime_type = part.inline_data.mime_type file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension - placeholder_text = '\nAvailable file: `%s`\n' % file_name + placeholder_text = f'\n{_AVAILABLE_FILE_PREFIX} `{file_name}`\n' # Store inline_data before replacing inline_data_copy = part.inline_data @@ -431,10 +441,12 @@ def _extract_and_replace_inline_files( and len(event.content.parts) > j ): event_part = event.content.parts[j] - # Match by inline_data content (comparing mime_type and data) + # Match by inline_data content (comparing mime_type, length, and data) + # Length check first for performance optimization if ( event_part.inline_data and event_part.inline_data.mime_type == mime_type + and len(event_part.inline_data.data) == len(inline_data_copy.data) and event_part.inline_data.data == inline_data_copy.data ): # Mark this event/part for update diff --git a/tests/unittests/flows/llm_flows/test_code_execution_persistence.py b/tests/unittests/flows/llm_flows/test_code_execution_persistence.py index d2f59fe4f1..4523fd6114 100644 --- a/tests/unittests/flows/llm_flows/test_code_execution_persistence.py +++ b/tests/unittests/flows/llm_flows/test_code_execution_persistence.py @@ -79,7 +79,7 @@ async def test_inline_data_replacement_in_extract_function(): ] ) - # Add to session events (simulating what happens in real flow) + # Add to session events from google.adk.events.event import Event user_event = Event( 
invocation_id='test_inv', @@ -100,21 +100,12 @@ async def test_inline_data_replacement_in_extract_function(): invocation_context.session = session # Call the function we're testing - print("\n=== BEFORE _extract_and_replace_inline_files ===") - print(f"LLM Request parts: {[type(p).__name__ for content in llm_request.contents for p in content.parts]}") - print(f"Session event parts: {[type(p).__name__ for e in session.events if e.content for p in e.content.parts]}") - result_files = _extract_and_replace_inline_files( code_executor_context, llm_request, invocation_context ) - print("\n=== AFTER _extract_and_replace_inline_files ===") - print(f"Files extracted: {len(result_files)}") - print(f"LLM Request parts: {[type(p).__name__ for content in llm_request.contents for p in content.parts]}") - print(f"Session event parts: {[type(p).__name__ for e in session.events if e.content for p in e.content.parts]}") - # Check LLM request was modified has_inline_in_request = any( p.inline_data @@ -127,16 +118,7 @@ async def test_inline_data_replacement_in_extract_function(): for p in content.parts ) - print(f"\nLLM Request:") - print(f" - Has inline_data: {has_inline_in_request}") - print(f" - Has placeholder: {has_placeholder_in_request}") - - for content in llm_request.contents: - for i, part in enumerate(content.parts): - if part.text: - print(f" - Part {i} text: {part.text[:50]}") - - # Check session events were modified (THIS IS THE KEY FIX) + # Check session events were modified user_events = [e for e in session.events if e.content and e.content.role == 'user'] assert len(user_events) > 0, "Should have user events" @@ -147,28 +129,16 @@ async def test_inline_data_replacement_in_extract_function(): for p in first_user_event.content.parts ) - print(f"\nSession Events:") - print(f" - Has inline_data: {has_inline_in_session}") - print(f" - Has placeholder: {has_placeholder_in_session}") - - for i, part in enumerate(first_user_event.content.parts): - if part.text: - print(f" - 
Part {i} text: {part.text[:50]}") - if part.inline_data: - print(f" - Part {i} has inline_data of size: {len(part.inline_data.data)}") - # Assertions for LLM request assert not has_inline_in_request, "inline_data should be replaced in LLM request" assert has_placeholder_in_request, "Placeholder should be present in LLM request" - # Critical assertions for session events (THE FIX) - assert not has_inline_in_session, "inline_data should be replaced in session.events (FIX REQUIRED)" - assert has_placeholder_in_session, "Placeholder should be present in session.events (FIX REQUIRED)" + # Critical assertions for session events + assert not has_inline_in_session, "inline_data should be replaced in session.events" + assert has_placeholder_in_session, "Placeholder should be present in session.events" # Test that files were extracted assert len(result_files) >= 1, "At least one file should be extracted" - - print("\n=== TEST PASSED ===") @pytest.mark.asyncio @@ -222,7 +192,6 @@ async def test_persistence_across_simulated_turns(): ) initial_file_count = len(files_1) - print(f"\nTurn 1: Extracted {initial_file_count} files") # Verify session was modified user_events = [e for e in session.events if e.content and e.content.role == 'user'] @@ -231,9 +200,8 @@ async def test_persistence_across_simulated_turns(): for e in user_events for p in e.content.parts ) - print(f"Turn 1: Session has inline_data after processing: {has_inline_after_turn1}") - # Turn 2: User sends follow-up (simulating new LLM request from session history) + # Turn 2: User sends follow-up user_content_2 = types.Content( role='user', parts=[types.Part(text="What is the sum?")] @@ -246,8 +214,7 @@ async def test_persistence_across_simulated_turns(): ) session.events.append(user_event_2) - # Create new LLM request with ALL session events (this is what happens in real flow) - # Deep copy to simulate what base_llm_flow.py does + # Create new LLM request with ALL session events (simulating real flow) all_contents = 
[] for event in session.events: if event.content: @@ -255,15 +222,12 @@ async def test_persistence_across_simulated_turns(): llm_request_2 = LlmRequest(contents=all_contents) - print(f"\nTurn 2: LLM request has {len(llm_request_2.contents)} contents") - - # Check if inline_data reappeared in the request + # Check if inline_data reappeared has_inline_before_process = any( p.inline_data for content in llm_request_2.contents for p in content.parts ) - print(f"Turn 2: LLM request has inline_data BEFORE processing: {has_inline_before_process}") # Process turn 2 files_2 = _extract_and_replace_inline_files( @@ -273,16 +237,8 @@ async def test_persistence_across_simulated_turns(): ) final_file_count = len(code_executor_context.get_input_files()) - print(f"Turn 2: Total files in context: {final_file_count}") - - # Critical assertion: if session.events were properly modified in turn 1, - # then turn 2 should NOT see inline_data (it should already be replaced) - if has_inline_before_process: - print("\n⚠️ FAIL: inline_data reappeared in turn 2 (FIX NOT WORKING)") - print("This means session.events were not properly modified in turn 1") - else: - print("\n✓ PASS: inline_data did not reappear in turn 2 (FIX WORKING)") + # Critical assertion assert not has_inline_before_process, \ "inline_data should NOT reappear in turn 2 if session.events were properly modified"