11from __future__ import annotations
22
33import asyncio
4+ import json
45import logging
56import sys
67
78from pydantic import BaseModel
89from pydantic_ai import Agent
9- from pydantic_ai .models .anthropic import AnthropicModel
10- from pydantic_ai .result import RunResult
10+ from pydantic_ai .agent import AgentRunResult
1111from typing_extensions import Any , Dict , Optional , Union
1212
13+ from patchwork .common .client .llm .protocol import LlmClient
1314from patchwork .common .client .llm .utils import example_json_to_base_model
1415from patchwork .common .tools import Tool
1516from patchwork .common .utils .utils import mustache_render
1617
1718_COMPLETION_FLAG_ATTRIBUTE = "is_task_completed"
1819_MESSAGE_ATTRIBUTE = "message"
20+ DEFAULT_AGENT_EXAMPLE_JSON = f'{{"{ _MESSAGE_ATTRIBUTE } ":"message", "{ _COMPLETION_FLAG_ATTRIBUTE } ": false}}'
1921
2022
2123class AgentConfig (BaseModel ):
@@ -25,15 +27,23 @@ class Config:
2527 name : str
2628 tool_set : Dict [str , Tool ]
2729 system_prompt : str = ""
28- example_json : Union [
29- str , Dict [str , Any ]
30- ] = f'{{"{ _MESSAGE_ATTRIBUTE } ":"message", "{ _COMPLETION_FLAG_ATTRIBUTE } ": false}}'
30+ example_json : Union [str , Dict [str , Any ]] = DEFAULT_AGENT_EXAMPLE_JSON
31+
32+ def model_post_init (self , __context : Any ) -> None :
33+ if self .example_json == DEFAULT_AGENT_EXAMPLE_JSON :
34+ return
35+
36+ wanted = json .loads (self .example_json )
37+ default_wanted = json .loads (DEFAULT_AGENT_EXAMPLE_JSON )
38+ default_wanted .update (wanted )
39+ self .example_json = json .dumps (default_wanted )
3140
3241
3342class AgenticStrategyV2 :
3443 def __init__ (
3544 self ,
36- api_key : str ,
45+ model : str ,
46+ llm_client : LlmClient ,
3747 template_data : dict [str , str ],
3848 system_prompt_template : str ,
3949 user_prompt_template : str ,
@@ -44,25 +54,30 @@ def __init__(
4454 self .__limit = limit
4555 self .__template_data = template_data
4656 self .__user_prompt_template = user_prompt_template
47- model = AnthropicModel ("claude-3-5-sonnet-latest" , api_key = api_key )
4857 self .__summariser = Agent (
49- model ,
58+ llm_client ,
5059 system_prompt = mustache_render (system_prompt_template , self .__template_data ),
5160 result_type = example_json_to_base_model (example_json ),
52- model_settings = dict (parallel_tool_calls = False ),
61+ model_settings = dict (
62+ parallel_tool_calls = False ,
63+ model = model ,
64+ ),
5365 )
5466 self .__agents = []
5567 for agent_config in agent_configs :
5668 tools = []
5769 for tool in agent_config .tool_set .values ():
5870 tools .append (tool .to_pydantic_ai_function_tool ())
5971 agent = Agent (
60- model ,
72+ llm_client ,
6173 name = agent_config .name ,
6274 system_prompt = mustache_render (agent_config .system_prompt , self .__template_data ),
6375 tools = tools ,
6476 result_type = example_json_to_base_model (agent_config .example_json ),
65- model_settings = dict (parallel_tool_calls = False ),
77+ model_settings = dict (
78+ parallel_tool_calls = False ,
79+ model = model ,
80+ ),
6681 )
6782
6883 self .__agents .append (agent )
@@ -89,7 +104,7 @@ def execute(self, limit: Optional[int] = None) -> dict:
89104 message_history = None
90105 agent_output = None
91106 for i in range (limit or self .__limit or sys .maxsize ):
92- agent_output : RunResult [Any ] = loop .run_until_complete (
107+ agent_output : AgentRunResult [Any ] = loop .run_until_complete (
93108 agent .run (user_message , message_history = message_history )
94109 )
95110 message_history = agent_output .all_messages ()
@@ -107,10 +122,11 @@ def execute(self, limit: Optional[int] = None) -> dict:
107122 return dict ()
108123
109124 if len (agents_result ) == 1 :
125+ history = next (v for _ , v in agents_result .items ()).all_messages ()
110126 final_result = loop .run_until_complete (
111127 self .__summariser .run (
112128 "From the actions taken by the assistant. Please give me the result." ,
113- message_history = next ( v for _ , v in agents_result . items ()). all_messages () ,
129+ message_history = history ,
114130 )
115131 )
116132 else :
0 commit comments