@@ -285,7 +285,12 @@ def cancel_on_event(
285285
286286 """
287287 if cancel_event .is_set ():
288- raise asyncio .CancelledError ('Function call was cancelled' )
288+ task = asyncio .current_task ()
289+ if task is not None :
290+ task .cancel ()
291+ raise asyncio .CancelledError (
292+ 'Function call was cancelled by client' ,
293+ )
289294
290295
291296def build_udf_endpoint (
@@ -314,19 +319,21 @@ def build_udf_endpoint(
314319
315320 async def do_func (
316321 cancel_event : threading .Event ,
322+ finished_event : threading .Event ,
317323 timer : Timer ,
318324 row_ids : Sequence [int ],
319325 rows : Sequence [Sequence [Any ]],
320326 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
321327 '''Call function on given rows of data.'''
322328 out = []
323- with timer ('call_function' ):
329+ async with timer ('call_function' ):
324330 for row in rows :
325331 cancel_on_event (cancel_event )
326332 if is_async :
327333 out .append (await func (* row ))
328334 else :
329335 out .append (func (* row ))
336+ finished_event .set ()
330337 return row_ids , list (zip (out ))
331338
332339 return do_func
@@ -360,6 +367,7 @@ def build_vector_udf_endpoint(
360367
361368 async def do_func (
362369 cancel_event : threading .Event ,
370+ finished_event : threading .Event ,
363371 timer : Timer ,
364372 row_ids : Sequence [int ],
365373 cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
@@ -371,7 +379,7 @@ async def do_func(
371379 row_ids = array_cls (row_ids )
372380
373381 # Call the function with `cols` as the function parameters
374- with timer ('call_function' ):
382+ async with timer ('call_function' ):
375383 if cols and cols [0 ]:
376384 if is_async :
377385 out = await func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
@@ -383,6 +391,7 @@ async def do_func(
383391 else :
384392 out = func ()
385393
394+ finished_event .set ()
386395 cancel_on_event (cancel_event )
387396
388397 # Single masked value
@@ -425,6 +434,7 @@ def build_tvf_endpoint(
425434
426435 async def do_func (
427436 cancel_event : threading .Event ,
437+ finished_event : threading .Event ,
428438 timer : Timer ,
429439 row_ids : Sequence [int ],
430440 rows : Sequence [Sequence [Any ]],
@@ -433,7 +443,7 @@ async def do_func(
433443 out_ids : List [int ] = []
434444 out = []
435445 # Call function on each row of data
436- with timer ('call_function' ):
446+ async with timer ('call_function' ):
437447 for i , row in zip (row_ids , rows ):
438448 cancel_on_event (cancel_event )
439449 if is_async :
@@ -442,6 +452,7 @@ async def do_func(
442452 res = func (* row )
443453 out .extend (as_list_of_tuples (res ))
444454 out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
455+ finished_event .set ()
445456 return out_ids , out
446457
447458 return do_func
@@ -474,6 +485,7 @@ def build_vector_tvf_endpoint(
474485
475486 async def do_func (
476487 cancel_event : threading .Event ,
488+ finished_event : threading .Event ,
477489 timer : Timer ,
478490 row_ids : Sequence [int ],
479491 cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
@@ -489,7 +501,7 @@ async def do_func(
489501 is_async = asyncio .iscoroutinefunction (func )
490502
491503 # Call function on each column of data
492- with timer ('call_function' ):
504+ async with timer ('call_function' ):
493505 if cols and cols [0 ]:
494506 if is_async :
495507 func_res = await func (
@@ -505,6 +517,8 @@ async def do_func(
505517 else :
506518 func_res = func ()
507519
520+ finished_event .set ()
521+
508522 res = get_dataframe_columns (func_res )
509523
510524 cancel_on_event (cancel_event )
@@ -616,15 +630,11 @@ async def cancel_on_disconnect(
616630 )
617631
618632
619- def cancel_all_tasks (tasks : Iterable [asyncio .Task [Any ]]) -> None :
633+ async def cancel_all_tasks (tasks : Iterable [asyncio .Task [Any ]]) -> None :
620634 """Cancel all tasks."""
621635 for task in tasks :
622- if task .done ():
623- continue
624- try :
625- task .cancel ()
626- except Exception :
627- pass
636+ task .cancel ()
637+ await asyncio .gather (* tasks , return_exceptions = True )
628638
629639
630640def start_counter () -> float :
@@ -1027,17 +1037,24 @@ async def __call__(
10271037 result = []
10281038
10291039 cancel_event = threading .Event ()
1040+ finished_event = threading .Event ()
1041+
1042+ # Async functions don't need to set the finished event
1043+ if func_info ['is_async' ]:
1044+ finished_event .set ()
10301045
10311046 with timer ('parse_input' ):
10321047 inputs = input_handler ['load' ]( # type: ignore
10331048 func_info ['colspec' ], b'' .join (data ),
10341049 )
10351050
10361051 func_task = asyncio .create_task (
1037- func (cancel_event , timer , * inputs )
1052+ func (cancel_event , finished_event , timer , * inputs )
10381053 if func_info ['is_async' ]
10391054 else to_thread (
1040- lambda : asyncio .run (func (cancel_event , timer , * inputs )),
1055+ lambda : asyncio .run (
1056+ func (cancel_event , finished_event , timer , * inputs ),
1057+ ),
10411058 ),
10421059 )
10431060 disconnect_task = asyncio .create_task (
@@ -1049,12 +1066,15 @@ async def __call__(
10491066
10501067 all_tasks += [func_task , disconnect_task , timeout_task ]
10511068
1052- with timer ('function_wrapper' ):
1069+ async with timer ('function_wrapper' ):
10531070 done , pending = await asyncio .wait (
10541071 all_tasks , return_when = asyncio .FIRST_COMPLETED ,
10551072 )
10561073
1057- cancel_all_tasks (pending )
1074+ await cancel_all_tasks (pending )
1075+
1076+ # Make sure threads finish before we proceed
1077+ finished_event .wait ()
10581078
10591079 for task in done :
10601080 if task is disconnect_task :
@@ -1105,7 +1125,7 @@ async def __call__(
11051125 await send (self .error_response_dict )
11061126
11071127 finally :
1108- cancel_all_tasks (all_tasks )
1128+ await cancel_all_tasks (all_tasks )
11091129
11101130 # Handle api reflection
11111131 elif method == 'GET' and path == self .show_create_function_path :
0 commit comments