Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input
except ContextWindowExceededError:
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
trajectory = self.truncate_trajectory(trajectory)
raise ValueError(
"The context window was exceeded even after 3 attempts to truncate the trajectory."
)

async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
for _ in range(3):
Expand All @@ -164,6 +167,9 @@ async def _async_call_with_potential_trajectory_truncation(self, module, traject
except ContextWindowExceededError:
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
trajectory = self.truncate_trajectory(trajectory)
raise ValueError(
"The context window was exceeded even after 3 attempts to truncate the trajectory."
)

def truncate_trajectory(self, trajectory):
"""Truncates the trajectory so that it fits in the context window.
Expand Down
33 changes: 33 additions & 0 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,39 @@ def mock_react(**kwargs):
assert result.output_text == "Final output"


@pytest.mark.asyncio
async def test_context_window_exceeded_after_retries():
def echo(text: str) -> str:
return f"Echoed: {text}"

react = dspy.ReAct("input_text -> output_text", tools=[echo])

def mock_react(**kwargs):
raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")

react.react = mock_react
react.extract = lambda **kwargs: dspy.Prediction(output_text="Fallback output")

# Test sync version
result = react(input_text="test input")
assert result.trajectory == {}
assert result.output_text == "Fallback output"

# Test async version
async def mock_react_async(**kwargs):
raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")

async def mock_extract_async(**kwargs):
return dspy.Prediction(output_text="Fallback output")

react.react.acall = mock_react_async
react.extract.acall = mock_extract_async

result = await react.acall(input_text="test input")
assert result.trajectory == {}
assert result.output_text == "Fallback output"


def test_error_retry():
# --- a tiny tool that always fails -------------------------------------
def foo(a, b):
Expand Down