diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 5f87879f80..e94b591cef 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -153,6 +153,9 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input except ContextWindowExceededError: logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") trajectory = self.truncate_trajectory(trajectory) + raise ValueError( + "The context window was exceeded even after 3 attempts to truncate the trajectory." + ) async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): for _ in range(3): @@ -164,6 +167,9 @@ async def _async_call_with_potential_trajectory_truncation(self, module, traject except ContextWindowExceededError: logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") trajectory = self.truncate_trajectory(trajectory) + raise ValueError( + "The context window was exceeded even after 3 attempts to truncate the trajectory." + ) def truncate_trajectory(self, trajectory): """Truncates the trajectory so that it fits in the context window. diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py index 09fd1c7c85..5295d5cef1 100644 --- a/tests/predict/test_react.py +++ b/tests/predict/test_react.py @@ -204,6 +204,39 @@ def mock_react(**kwargs): assert result.output_text == "Final output" +@pytest.mark.asyncio +async def test_context_window_exceeded_after_retries(): + def echo(text: str) -> str: + return f"Echoed: {text}" + + react = dspy.ReAct("input_text -> output_text", tools=[echo]) + + def mock_react(**kwargs): + raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider") + + react.react = mock_react + react.extract = lambda **kwargs: dspy.Prediction(output_text="Fallback output") + + # Test sync version + result = react(input_text="test input") + assert result.trajectory == {} + assert result.output_text == "Fallback output" + + # Test async version + async def mock_react_async(**kwargs): + raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider") + + async def mock_extract_async(**kwargs): + return dspy.Prediction(output_text="Fallback output") + + react.react.acall = mock_react_async + react.extract.acall = mock_extract_async + + result = await react.acall(input_text="test input") + assert result.trajectory == {} + assert result.output_text == "Fallback output" + + def test_error_retry(): # --- a tiny tool that always fails ------------------------------------- def foo(a, b):