1+ import asyncio
12import functools
23import inspect
34from typing import Any
1920]
2021
2122ReturnType = ParameterType
23+ UDFType = Callable [..., Any ]
2224
2325
2426def is_valid_type (obj : Any ) -> bool :
@@ -100,38 +102,50 @@ def _func(
100102 name : Optional [str ] = None ,
101103 args : Optional [ParameterType ] = None ,
102104 returns : Optional [ReturnType ] = None ,
103- ) -> Callable [..., Any ]:
105+ timeout : Optional [int ] = None ,
106+ ) -> UDFType :
104107 """Generic wrapper for UDF and TVF decorators."""
105108
106109 _singlestoredb_attrs = { # type: ignore
107110 k : v for k , v in dict (
108111 name = name ,
109112 args = expand_types (args ),
110113 returns = expand_types (returns ),
114+ timeout = timeout ,
111115 ).items () if v is not None
112116 }
113117
114118 # No func was specified, this is an uncalled decorator that will get
115119 # called later, so the wrapper much be created with the func passed
116120 # in at that time.
117121 if func is None :
118- def decorate (func : Callable [..., Any ] ) -> Callable [..., Any ] :
122+ def decorate (func : UDFType ) -> UDFType :
119123
120- def wrapper (* args : Any , ** kwargs : Any ) -> Callable [..., Any ]:
121- return func (* args , ** kwargs ) # type: ignore
124+ if asyncio .iscoroutinefunction (func ):
125+ async def async_wrapper (* args : Any , ** kwargs : Any ) -> UDFType :
126+ return await func (* args , ** kwargs ) # type: ignore
127+ async_wrapper ._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
128+ return functools .wraps (func )(async_wrapper )
122129
123- wrapper ._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
124-
125- return functools .wraps (func )(wrapper )
130+ else :
131+ def wrapper (* args : Any , ** kwargs : Any ) -> UDFType :
132+ return func (* args , ** kwargs ) # type: ignore
133+ wrapper ._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
134+ return functools .wraps (func )(wrapper )
126135
127136 return decorate
128137
129- def wrapper (* args : Any , ** kwargs : Any ) -> Callable [..., Any ]:
130- return func (* args , ** kwargs ) # type: ignore
131-
132- wrapper ._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
138+ if asyncio .iscoroutinefunction (func ):
139+ async def async_wrapper (* args : Any , ** kwargs : Any ) -> UDFType :
140+ return await func (* args , ** kwargs ) # type: ignore
141+ async_wrapper ._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
142+ return functools .wraps (func )(async_wrapper )
133143
134- return functools .wraps (func )(wrapper )
144+ else :
145+ def wrapper (* args : Any , ** kwargs : Any ) -> UDFType :
146+ return func (* args , ** kwargs ) # type: ignore
147+ wrapper ._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
148+ return functools .wraps (func )(wrapper )
135149
136150
137151def udf (
@@ -140,7 +154,8 @@ def udf(
140154 name : Optional [str ] = None ,
141155 args : Optional [ParameterType ] = None ,
142156 returns : Optional [ReturnType ] = None ,
143- ) -> Callable [..., Any ]:
157+ timeout : Optional [int ] = None ,
158+ ) -> UDFType :
144159 """
145160 Define a user-defined function (UDF).
146161
@@ -167,6 +182,9 @@ def udf(
167182 Specifies the return data type of the function. This parameter
168183 works the same way as `args`. If the function is a table-valued
169184 function, the return type should be a `Table` object.
185+ timeout : int, optional
186+ The timeout in seconds for the UDF execution. If not specified,
187+ the global default timeout is used.
170188
171189 Returns
172190 -------
@@ -178,4 +196,5 @@ def udf(
178196 name = name ,
179197 args = args ,
180198 returns = returns ,
199+ timeout = timeout ,
181200 )
0 commit comments