44from typing import Any , Callable , Optional , TypeVar , Union , cast
55
66from prefect import Flow as PrefectFlow
7+ from prefect import Task as PrefectTask
78from prefect .utilities .asyncutils import run_coro_as_sync
89from typing_extensions import ParamSpec
910
@@ -34,7 +35,7 @@ def flow(
3435 timeout_seconds : Optional [Union [float , int ]] = None ,
3536 prefect_kwargs : Optional [dict [str , Any ]] = None ,
3637 context_kwargs : Optional [list [str ]] = None ,
37- ** kwargs : Optional [ dict [ str , Any ]] ,
38+ ** kwargs : Any ,
3839) -> Callable [[Callable [P , R ]], PrefectFlow [P , R ]]:
3940 """
4041 A decorator that wraps a function as a ControlFlow flow.
@@ -75,13 +76,15 @@ def flow(
7576 sig = inspect .signature (fn )
7677
7778 def create_flow_context (bound_args ):
78- flow_kwargs = kwargs .copy ()
79+ flow_kwargs : dict [ str , Any ] = kwargs .copy ()
7980 if thread is not None :
80- flow_kwargs . setdefault ( "thread_id" , thread ) # type: ignore
81+ flow_kwargs [ "thread_id" ] = thread
8182 if tools is not None :
82- flow_kwargs . setdefault ( "tools" , tools ) # type: ignore
83+ flow_kwargs [ "tools" ] = tools
8384 if default_agent is not None :
84- flow_kwargs .setdefault ("default_agent" , default_agent ) # type: ignore
85+ flow_kwargs ["default_agent" ] = default_agent
86+
87+ flow_kwargs .update (kwargs )
8588
8689 context = {}
8790 if context_kwargs :
@@ -117,17 +120,19 @@ def wrapper(*wrapper_args, **wrapper_kwargs):
117120 ):
118121 return fn (* wrapper_args , ** wrapper_kwargs )
119122
120- prefect_wrapper = prefect_flow (
121- timeout_seconds = timeout_seconds ,
122- retries = retries ,
123- retry_delay_seconds = retry_delay_seconds ,
124- ** (prefect_kwargs or {}),
125- )(wrapper )
126- return cast (Callable [[Callable [P , R ]], PrefectFlow [P , R ]], prefect_wrapper )
123+ return cast (
124+ Callable [[Callable [P , R ]], PrefectFlow [P , R ]],
125+ prefect_flow (
126+ timeout_seconds = timeout_seconds ,
127+ retries = retries ,
128+ retry_delay_seconds = retry_delay_seconds ,
129+ ** (prefect_kwargs or {}),
130+ )(wrapper ),
131+ )
127132
128133
129134def task (
130- fn : Optional [Callable [..., Any ]] = None ,
135+ fn : Optional [Callable [P , R ]] = None ,
131136 * ,
132137 objective : Optional [str ] = None ,
133138 instructions : Optional [str ] = None ,
@@ -138,8 +143,8 @@ def task(
138143 retries : Optional [int ] = None ,
139144 retry_delay_seconds : Optional [Union [float , int ]] = None ,
140145 timeout_seconds : Optional [Union [float , int ]] = None ,
141- ** task_kwargs : Optional [ dict [ str , Any ]] ,
142- ):
146+ ** task_kwargs : Any ,
147+ ) -> Callable [[ Callable [ P , R ]], PrefectTask [ P , R ]] :
143148 """
144149 A decorator that turns a Python function into a Task. The Task objective is
145150 set to the function name, and the instructions are set to the function
@@ -162,78 +167,68 @@ def task(
162167 callable: The wrapped function or a new task decorator if `fn` is not provided.
163168 """
164169
165- if fn is None :
166- return functools .partial (
167- task ,
168- objective = objective ,
169- instructions = instructions ,
170- name = name ,
171- agents = agents ,
172- tools = tools ,
173- interactive = interactive ,
174- retries = retries ,
175- retry_delay_seconds = retry_delay_seconds ,
176- timeout_seconds = timeout_seconds ,
177- ** task_kwargs ,
178- )
179-
180- sig = inspect .signature (fn )
181-
182- if name is None :
183- name = fn .__name__
184-
185- if objective is None :
186- objective = fn .__doc__ or ""
170+ def decorator (func : Callable [P , R ]) -> PrefectTask [P , R ]:
171+ sig = inspect .signature (func )
187172
188- result_type = fn .__annotations__ .get ("return" )
189-
190- def _get_task (* args , ** kwargs ) -> Task :
191- # first process callargs
192- bound = sig .bind (* args , ** kwargs )
193- bound .apply_defaults ()
194- context = bound .arguments .copy ()
195-
196- # call the function to see if it produces an updated objective
197- maybe_coro = fn (* args , ** kwargs )
198- if asyncio .iscoroutine (maybe_coro ):
199- result = run_coro_as_sync (maybe_coro )
173+ if name is None :
174+ task_name = func .__name__
200175 else :
201- result = maybe_coro
202- if result is not None :
203- context ["Additional context" ] = result
176+ task_name = name
204177
205- return Task (
206- objective = objective ,
207- instructions = instructions ,
208- name = name ,
209- agents = agents ,
210- context = context ,
211- result_type = result_type ,
212- interactive = interactive or False ,
213- tools = tools or [],
214- ** task_kwargs ,
215- )
178+ if objective is None :
179+ task_objective = func .__doc__ or ""
180+ else :
181+ task_objective = objective
216182
217- if asyncio . iscoroutinefunction ( fn ):
183+ result_type = func . __annotations__ . get ( "return" )
218184
219- @functools .wraps (fn )
220- async def wrapper (* args , ** kwargs ):
221- task = _get_task (* args , ** kwargs )
222- return await task .run_async ()
223- else :
185+ def _get_task (* args , ** kwargs ) -> Task :
186+ bound = sig .bind (* args , ** kwargs )
187+ bound .apply_defaults ()
188+ context = bound .arguments .copy ()
189+
190+ maybe_coro = func (* args , ** kwargs )
191+ if asyncio .iscoroutine (maybe_coro ):
192+ result = run_coro_as_sync (maybe_coro )
193+ else :
194+ result = maybe_coro
195+ if result is not None :
196+ context ["Additional context" ] = result
197+
198+ return Task (
199+ objective = task_objective ,
200+ instructions = instructions ,
201+ name = task_name ,
202+ agents = agents ,
203+ context = context ,
204+ result_type = result_type ,
205+ interactive = interactive or False ,
206+ tools = tools or [],
207+ ** task_kwargs ,
208+ )
209+
210+ if asyncio .iscoroutinefunction (func ):
211+
212+ @functools .wraps (func )
213+ async def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> R : # type: ignore
214+ task = _get_task (* args , ** kwargs )
215+ return await task .run_async () # type: ignore
216+ else :
224217
225- @functools .wraps (fn )
226- def wrapper (* args , ** kwargs ) :
227- task = _get_task (* args , ** kwargs )
228- return task .run ()
218+ @functools .wraps (func )
219+ def wrapper (* args : P . args , ** kwargs : P . kwargs ) -> R :
220+ task = _get_task (* args , ** kwargs )
221+ return task .run () # type: ignore
229222
230- prefect_wrapper = prefect_task (
231- timeout_seconds = timeout_seconds ,
232- retries = retries ,
233- retry_delay_seconds = retry_delay_seconds ,
234- )(wrapper )
223+ prefect_wrapper = prefect_task (
224+ timeout_seconds = timeout_seconds ,
225+ retries = retries ,
226+ retry_delay_seconds = retry_delay_seconds ,
227+ )(wrapper )
235228
236- # store the `as_task` method for loading the task object
237- prefect_wrapper . as_task = _get_task
229+ setattr ( prefect_wrapper , "as_task" , _get_task )
230+ return cast ( PrefectTask [ P , R ], prefect_wrapper )
238231
239- return cast (Callable [[Callable [..., Any ]], Task ], prefect_wrapper )
232+ if fn is None :
233+ return decorator
234+ return decorator (fn ) # type: ignore
0 commit comments