88from typing import Any , Callable , Generic , Literal , Union , cast , get_args , get_origin
99
1010from pydantic import TypeAdapter , ValidationError
11- from typing_extensions import Self , TypeAliasType , TypedDict
11+ from typing_extensions import TypeAliasType , TypedDict , TypeVar
1212
1313from . import _utils , messages as _messages
1414from .exceptions import ModelRetry
15- from .result import ResultData , ResultValidatorFunc
16- from .tools import AgentDeps , RunContext , ToolDefinition
15+ from .result import ResultDataT , ResultDataT_inv , ResultValidatorFunc
16+ from .tools import AgentDepsT , RunContext , ToolDefinition
17+
18+ T = TypeVar ('T' )
19+ """An invariant TypeVar."""
1720
1821
1922@dataclass
20- class ResultValidator (Generic [AgentDeps , ResultData ]):
21- function : ResultValidatorFunc [AgentDeps , ResultData ]
23+ class ResultValidator (Generic [AgentDepsT , ResultDataT_inv ]):
24+ function : ResultValidatorFunc [AgentDepsT , ResultDataT_inv ]
2225 _takes_ctx : bool = field (init = False )
2326 _is_async : bool = field (init = False )
2427
@@ -28,10 +31,10 @@ def __post_init__(self):
2831
2932 async def validate (
3033 self ,
31- result : ResultData ,
34+ result : T ,
3235 tool_call : _messages .ToolCallPart | None ,
33- run_context : RunContext [AgentDeps ],
34- ) -> ResultData :
36+ run_context : RunContext [AgentDepsT ],
37+ ) -> T :
3538 """Validate a result but calling the function.
3639
3740 Args:
@@ -50,10 +53,10 @@ async def validate(
5053
5154 try :
5255 if self ._is_async :
53- function = cast (Callable [[Any ], Awaitable [ResultData ]], self .function )
56+ function = cast (Callable [[Any ], Awaitable [T ]], self .function )
5457 result_data = await function (* args )
5558 else :
56- function = cast (Callable [[Any ], ResultData ], self .function )
59+ function = cast (Callable [[Any ], T ], self .function )
5760 result_data = await _utils .run_in_executor (function , * args )
5861 except ModelRetry as r :
5962 m = _messages .RetryPromptPart (content = r .message )
@@ -74,17 +77,19 @@ def __init__(self, tool_retry: _messages.RetryPromptPart):
7477
7578
7679@dataclass
77- class ResultSchema (Generic [ResultData ]):
80+ class ResultSchema (Generic [ResultDataT ]):
7881 """Model the final response from an agent run.
7982
8083 Similar to `Tool` but for the final result of running an agent.
8184 """
8285
83- tools : dict [str , ResultTool [ResultData ]]
86+ tools : dict [str , ResultTool [ResultDataT ]]
8487 allow_text_result : bool
8588
8689 @classmethod
87- def build (cls , response_type : type [ResultData ], name : str , description : str | None ) -> Self | None :
90+ def build (
91+ cls : type [ResultSchema [T ]], response_type : type [T ], name : str , description : str | None
92+ ) -> ResultSchema [T ] | None :
8893 """Build a ResultSchema dataclass from a response type."""
8994 if response_type is str :
9095 return None
@@ -95,10 +100,10 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
95100 else :
96101 allow_text_result = False
97102
98- def _build_tool (a : Any , tool_name_ : str , multiple : bool ) -> ResultTool [ResultData ]:
99- return cast (ResultTool [ResultData ], ResultTool (a , tool_name_ , description , multiple ))
103+ def _build_tool (a : Any , tool_name_ : str , multiple : bool ) -> ResultTool [T ]:
104+ return cast (ResultTool [T ], ResultTool (a , tool_name_ , description , multiple ))
100105
101- tools : dict [str , ResultTool [ResultData ]] = {}
106+ tools : dict [str , ResultTool [T ]] = {}
102107 if args := get_union_args (response_type ):
103108 for i , arg in enumerate (args , start = 1 ):
104109 tool_name = union_tool_name (name , arg )
@@ -112,7 +117,7 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat
112117
113118 def find_named_tool (
114119 self , parts : Iterable [_messages .ModelResponsePart ], tool_name : str
115- ) -> tuple [_messages .ToolCallPart , ResultTool [ResultData ]] | None :
120+ ) -> tuple [_messages .ToolCallPart , ResultTool [ResultDataT ]] | None :
116121 """Find a tool that matches one of the calls, with a specific name."""
117122 for part in parts :
118123 if isinstance (part , _messages .ToolCallPart ):
@@ -122,7 +127,7 @@ def find_named_tool(
122127 def find_tool (
123128 self ,
124129 parts : Iterable [_messages .ModelResponsePart ],
125- ) -> tuple [_messages .ToolCallPart , ResultTool [ResultData ]] | None :
130+ ) -> tuple [_messages .ToolCallPart , ResultTool [ResultDataT ]] | None :
126131 """Find a tool that matches one of the calls."""
127132 for part in parts :
128133 if isinstance (part , _messages .ToolCallPart ):
@@ -142,11 +147,11 @@ def tool_defs(self) -> list[ToolDefinition]:
142147
143148
144149@dataclass (init = False )
145- class ResultTool (Generic [ResultData ]):
150+ class ResultTool (Generic [ResultDataT ]):
146151 tool_def : ToolDefinition
147152 type_adapter : TypeAdapter [Any ]
148153
149- def __init__ (self , response_type : type [ResultData ], name : str , description : str | None , multiple : bool ):
154+ def __init__ (self , response_type : type [ResultDataT ], name : str , description : str | None , multiple : bool ):
150155 """Build a ResultTool dataclass from a response type."""
151156 assert response_type is not str , 'ResultTool does not support str as a response type'
152157
@@ -183,7 +188,7 @@ def __init__(self, response_type: type[ResultData], name: str, description: str
183188
184189 def validate (
185190 self , tool_call : _messages .ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
186- ) -> ResultData :
191+ ) -> ResultDataT :
187192 """Validate a result message.
188193
189194 Args:
0 commit comments