@@ -562,6 +562,8 @@ async def on_complete():
562562 parts = await self ._process_function_tools (
563563 tool_calls , result_tool_name , run_context , result_schema
564564 )
565+ if any (isinstance (part , _messages .RetryPromptPart ) for part in parts ):
566+ self ._incr_result_retry (run_context )
565567 if parts :
566568 messages .append (_messages .ModelRequest (parts ))
567569 run_span .set_attribute ('all_messages' , messages )
@@ -1147,7 +1149,6 @@ async def _handle_structured_response(
11471149 result_data = result_tool .validate (call )
11481150 result_data = await self ._validate_result (result_data , run_context , call )
11491151 except _result .ToolRetryError as e :
1150- self ._incr_result_retry (run_context )
11511152 parts .append (e .tool_retry )
11521153 else :
11531154 final_result = _MarkFinalResult (result_data , call .tool_name )
@@ -1157,6 +1158,9 @@ async def _handle_structured_response(
11571158 tool_calls , final_result and final_result .tool_name , run_context , result_schema
11581159 )
11591160
1161+ if any (isinstance (part , _messages .RetryPromptPart ) for part in parts ):
1162+ self ._incr_result_retry (run_context )
1163+
11601164 return final_result , parts
11611165
11621166 async def _process_function_tools (
@@ -1210,7 +1214,7 @@ async def _process_function_tools(
12101214 )
12111215 )
12121216 else :
1213- parts .append (self ._unknown_tool (call .tool_name , run_context , result_schema ))
1217+ parts .append (self ._unknown_tool (call .tool_name , result_schema ))
12141218
12151219 # Run all tool tasks in parallel
12161220 if tasks :
@@ -1257,7 +1261,7 @@ async def _handle_streamed_response(
12571261 if tool := self ._function_tools .get (p .tool_name ):
12581262 tasks .append (asyncio .create_task (tool .run (p , run_context ), name = p .tool_name ))
12591263 else :
1260- parts .append (self ._unknown_tool (p .tool_name , run_context , result_schema ))
1264+ parts .append (self ._unknown_tool (p .tool_name , result_schema ))
12611265
12621266 if received_text and not tasks and not parts :
12631267 # Can only get here if self._allow_text_result returns `False` for the provided result_schema
@@ -1270,6 +1274,10 @@ async def _handle_streamed_response(
12701274 with _logfire .span ('running {tools=}' , tools = [t .get_name () for t in tasks ]):
12711275 task_results : Sequence [_messages .ModelRequestPart ] = await asyncio .gather (* tasks )
12721276 parts .extend (task_results )
1277+
1278+ if any (isinstance (part , _messages .RetryPromptPart ) for part in parts ):
1279+ self ._incr_result_retry (run_context )
1280+
12731281 return model_response , parts
12741282
12751283 async def _validate_result (
@@ -1307,10 +1315,8 @@ async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_message
13071315 def _unknown_tool (
13081316 self ,
13091317 tool_name : str ,
1310- run_context : RunContext [AgentDepsT ],
13111318 result_schema : _result .ResultSchema [RunResultData ] | None ,
13121319 ) -> _messages .RetryPromptPart :
1313- self ._incr_result_retry (run_context )
13141320 names = list (self ._function_tools .keys ())
13151321 if result_schema :
13161322 names .extend (result_schema .tool_names ())
0 commit comments