95 changes: 87 additions & 8 deletions src/google/adk/flows/llm_flows/_code_execution.py
@@ -50,6 +50,8 @@

logger = logging.getLogger('google_adk.' + __name__)

_AVAILABLE_FILE_PREFIX = 'Available file:'


@dataclasses.dataclass
class DataFileUtil:
@@ -206,7 +208,7 @@ async def _run_pre_processor(
# memory. Meanwhile, mutate the inline data files into text parts in the
# session history across all turns.
all_input_files = _extract_and_replace_inline_files(
code_executor_context, llm_request
code_executor_context, llm_request, invocation_context
)

# [Step 2] Run Explore_Df code on the data files from the current turn. We
@@ -375,20 +377,42 @@ async def _run_post_processor(
def _extract_and_replace_inline_files(
code_executor_context: CodeExecutorContext,
llm_request: LlmRequest,
invocation_context: InvocationContext,
) -> list[File]:
"""Extracts and replaces inline files with file names in the LLM request."""
"""Extracts and replaces inline files with file names in the LLM request.

This function modifies both `llm_request.contents` for the current request
and `invocation_context.session.events` to ensure the replacement of inline
data with placeholders persists across conversation turns.

Args:
code_executor_context: Context containing code executor state.
llm_request: The LLM request to process.
invocation_context: Context containing session information.

Returns:
List of extracted File objects.
"""
all_input_files = code_executor_context.get_input_files()
saved_file_names = set(f.name for f in all_input_files)

# [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
# Track which session events need to be updated
events_to_update = {}

# Process input files from LlmRequest and cache them in CodeExecutor.
for i in range(len(llm_request.contents)):
content = llm_request.contents[i]
# Only process the user message.
if content.role != 'user' and not content.parts:
if content.role != 'user' or not content.parts:
continue

for j in range(len(content.parts)):
part = content.parts[j]

# Skip if already processed (already a placeholder)
if part.text and _AVAILABLE_FILE_PREFIX in part.text:
continue

# Skip if the inline data is not supported.
if (
not part.inline_data
@@ -399,21 +423,76 @@ def _extract_and_replace_inline_files(
# Replace the inline data file with a file name placeholder.
mime_type = part.inline_data.mime_type
file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension
llm_request.contents[i].parts[j] = types.Part(
text='\nAvailable file: `%s`\n' % file_name
)
placeholder_text = f'\n{_AVAILABLE_FILE_PREFIX} `{file_name}`\n'

# Store inline_data before replacing
inline_data_copy = part.inline_data

# Replace in llm_request
llm_request.contents[i].parts[j] = types.Part(text=placeholder_text)

# Find the corresponding session event and mark it for update
# so the replacement can be persisted across turns.
session = invocation_context.session
for event_idx, event in enumerate(session.events):
if (
event.content
and event.content.role == 'user'
and len(event.content.parts) > j
):
event_part = event.content.parts[j]
# Match by inline_data content (comparing mime_type, length, and data)
# Length check first for performance optimization
if (
event_part.inline_data
and event_part.inline_data.mime_type == mime_type
and len(event_part.inline_data.data) == len(inline_data_copy.data)
and event_part.inline_data.data == inline_data_copy.data
):
# Mark this event/part for update
if event_idx not in events_to_update:
events_to_update[event_idx] = {}
events_to_update[event_idx][j] = placeholder_text
break

# Add the inline data as input file to the code executor context.
file = File(
name=file_name,
content=CodeExecutionUtils.get_encoded_file_content(
part.inline_data.data
inline_data_copy.data
).decode(),
mime_type=mime_type,
)
if file_name not in saved_file_names:
code_executor_context.add_input_files([file])
all_input_files.append(file)
saved_file_names.add(file_name)

# Apply updates to session.events to persist across turns
session = invocation_context.session
for event_idx, parts_to_update in events_to_update.items():
event = session.events[event_idx]
# Create new parts list with replacements
updated_parts = list(event.content.parts)
for part_idx, placeholder_text in parts_to_update.items():
updated_parts[part_idx] = types.Part(text=placeholder_text)

# Create new content with updated parts
updated_content = types.Content(
role=event.content.role, parts=updated_parts
)

# Replace the event in session.events; the list itself is updated in place.
# Event is a Pydantic model, so use model_copy() rather than dataclasses.replace().
session.events[event_idx] = event.model_copy(
update={'content': updated_content}
)

logger.debug(
'Replaced inline_data in session.events[%d] with placeholder: %s',
event_idx,
placeholder_text.strip(),
)

return all_input_files
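
As a reference for reviewers, here is a minimal, self-contained sketch of the persistence step above. The `SessionEvent`, `EventContent`, and `EventPart` classes are hypothetical stand-ins, not the real ADK types; the sketch only illustrates the two mechanisms the new code relies on: locating the matching user part by comparing the inline bytes (length check first, then full equality) and swapping the event's content through Pydantic's `model_copy()` rather than mutating the nested part.

```python
from typing import Optional

from pydantic import BaseModel


class EventPart(BaseModel):
  text: Optional[str] = None
  inline_bytes: Optional[bytes] = None


class EventContent(BaseModel):
  role: str
  parts: list[EventPart]


class SessionEvent(BaseModel):
  invocation_id: str
  content: EventContent


def persist_placeholder(
    events: list[SessionEvent], data: bytes, part_idx: int, placeholder: str
) -> None:
  """Replaces the matching inline part in the event list with a placeholder."""
  for idx, event in enumerate(events):
    if event.content.role != 'user' or len(event.content.parts) <= part_idx:
      continue
    part = event.content.parts[part_idx]
    # Cheap length check first, then full byte comparison.
    if (
        part.inline_bytes is not None
        and len(part.inline_bytes) == len(data)
        and part.inline_bytes == data
    ):
      parts = list(event.content.parts)
      parts[part_idx] = EventPart(text=placeholder)
      # Events are Pydantic models, so assign back a copy with the new
      # content instead of mutating the nested part in place.
      events[idx] = event.model_copy(
          update={'content': EventContent(role='user', parts=parts)}
      )
      break


events = [
    SessionEvent(
        invocation_id='inv_1',
        content=EventContent(
            role='user',
            parts=[
                EventPart(text='Process this'),
                EventPart(inline_bytes=b'name,value\nA,100'),
            ],
        ),
    )
]
persist_placeholder(
    events, b'name,value\nA,100', 1, '\nAvailable file: `data_1_2.csv`\n'
)
assert events[0].content.parts[1].inline_bytes is None
assert 'Available file:' in events[0].content.parts[1].text
```

Assigning the copied event back into the list mirrors how the diff updates `session.events[event_idx]` instead of mutating the nested part, so the original Event object is left untouched.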

247 changes: 247 additions & 0 deletions tests/unittests/flows/llm_flows/test_code_execution_persistence.py
@@ -0,0 +1,247 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for optimize_data_file persistence across turns."""

import copy
from typing import Dict

import pytest
from google.adk.code_executors.base_code_executor import BaseCodeExecutor
from google.adk.code_executors.code_execution_utils import (
    CodeExecutionInput,
    CodeExecutionResult,
    File,
)
from google.adk.code_executors.code_executor_context import CodeExecutorContext
from google.adk.models.llm_request import LlmRequest
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.genai import types
from pydantic import Field


class MockCodeExecutor(BaseCodeExecutor):
"""Mock executor for testing."""

# Define as Pydantic fields
injected_files: Dict[str, str] = Field(default_factory=dict)
execution_count: int = Field(default=0)

def execute_code(self, invocation_context, code_input: CodeExecutionInput):
"""Mock code execution."""
self.execution_count += 1
# Store files if they're new
for file in code_input.input_files:
if file.name not in self.injected_files:
self.injected_files[file.name] = file.content

return CodeExecutionResult(
stdout=f"Executed: {len(code_input.input_files)} files available",
stderr="",
output_files=[]
)


@pytest.mark.asyncio
async def test_inline_data_replacement_in_extract_function():
"""Test that _extract_and_replace_inline_files modifies both request and session."""
from google.adk.flows.llm_flows._code_execution import _extract_and_replace_inline_files

# Create CSV data
csv_data = b"name,value\nA,100\nB,200"

# Create a mock session with events
session = Session(
id='test_session',
app_name='test_app',
user_id='test_user',
state={},
events=[]
)

# Create user content with inline_data
user_content = types.Content(
role='user',
parts=[
types.Part(text="Process this"),
types.Part(inline_data=types.Blob(mime_type='text/csv', data=csv_data))
]
)

# Add to session events
from google.adk.events.event import Event
user_event = Event(
invocation_id='test_inv',
author='user',
content=user_content,
)
session.events.append(user_event)

# Create LLM request
llm_request = LlmRequest(contents=[copy.deepcopy(user_content)])

# Create code executor context
code_executor_context = CodeExecutorContext(session.state)

# Create mock invocation context
from unittest.mock import Mock
invocation_context = Mock()
invocation_context.session = session

# Call the function we're testing
result_files = _extract_and_replace_inline_files(
code_executor_context,
llm_request,
invocation_context
)

# Check LLM request was modified
has_inline_in_request = any(
p.inline_data
for content in llm_request.contents
for p in content.parts
)
has_placeholder_in_request = any(
'Available file:' in (p.text or '')
for content in llm_request.contents
for p in content.parts
)

# Check session events were modified
user_events = [e for e in session.events if e.content and e.content.role == 'user']
assert len(user_events) > 0, "Should have user events"

first_user_event = user_events[0]
has_inline_in_session = any(p.inline_data for p in first_user_event.content.parts)
has_placeholder_in_session = any(
'Available file:' in (p.text or '')
for p in first_user_event.content.parts
)

# Assertions for LLM request
assert not has_inline_in_request, "inline_data should be replaced in LLM request"
assert has_placeholder_in_request, "Placeholder should be present in LLM request"

# Critical assertions for session events
assert not has_inline_in_session, "inline_data should be replaced in session.events"
assert has_placeholder_in_session, "Placeholder should be present in session.events"

# Test that files were extracted
assert len(result_files) >= 1, "At least one file should be extracted"


@pytest.mark.asyncio
async def test_persistence_across_simulated_turns():
"""Test that on a second 'turn', inline_data doesn't reappear."""
from google.adk.flows.llm_flows._code_execution import _extract_and_replace_inline_files

# Create CSV data
csv_data = b"name,value\nA,100\nB,200"

# Create a session
session = Session(
id='test_session',
app_name='test_app',
user_id='test_user',
state={},
events=[]
)

# Turn 1: User sends CSV
user_content_1 = types.Content(
role='user',
parts=[
types.Part(text="Process this"),
types.Part(inline_data=types.Blob(mime_type='text/csv', data=csv_data))
]
)

from google.adk.events.event import Event
user_event_1 = Event(
invocation_id='test_inv_1',
author='user',
content=user_content_1,
)
session.events.append(user_event_1)

# Create code executor context
code_executor_context = CodeExecutorContext(session.state)

# Mock invocation context
from unittest.mock import Mock
invocation_context = Mock()
invocation_context.session = session

# Process turn 1
llm_request_1 = LlmRequest(contents=[copy.deepcopy(user_content_1)])
files_1 = _extract_and_replace_inline_files(
code_executor_context,
llm_request_1,
invocation_context
)

initial_file_count = len(files_1)

# Verify session was modified
user_events = [e for e in session.events if e.content and e.content.role == 'user']
has_inline_after_turn1 = any(
p.inline_data
for e in user_events
for p in e.content.parts
)
assert not has_inline_after_turn1, \
    "inline_data should be replaced in session.events after turn 1"

# Turn 2: User sends follow-up
user_content_2 = types.Content(
role='user',
parts=[types.Part(text="What is the sum?")]
)

user_event_2 = Event(
invocation_id='test_inv_2',
author='user',
content=user_content_2,
)
session.events.append(user_event_2)

# Create new LLM request with ALL session events (simulating real flow)
all_contents = []
for event in session.events:
if event.content:
all_contents.append(copy.deepcopy(event.content))

llm_request_2 = LlmRequest(contents=all_contents)

# Check if inline_data reappeared
has_inline_before_process = any(
p.inline_data
for content in llm_request_2.contents
for p in content.parts
)

# Process turn 2
files_2 = _extract_and_replace_inline_files(
code_executor_context,
llm_request_2,
invocation_context
)

final_file_count = len(code_executor_context.get_input_files())

# Critical assertion
assert not has_inline_before_process, \
"inline_data should NOT reappear in turn 2 if session.events were properly modified"


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
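
These tests can be run on their own through the `pytest.main` entry point above (for example, `python tests/unittests/flows/llm_flows/test_code_execution_persistence.py`) or via the regular runner with `pytest tests/unittests/flows/llm_flows/test_code_execution_persistence.py -v`.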