Skip to content

Commit 2e874a7

Browse files
committed
Improve UDF cancellation
1 parent 24eb671 commit 2e874a7

File tree

5 files changed

+49
-19
lines changed

5 files changed

+49
-19
lines changed

singlestoredb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
DataError, ManagementError,
2626
)
2727
from .management import (
28-
manage_cluster, manage_workspaces, manage_files,
28+
manage_cluster, manage_workspaces, manage_files, manage_regions,
2929
)
3030
from .types import (
3131
Date, Time, Timestamp, DateFromTicks, TimeFromTicks, TimestampFromTicks,

singlestoredb/functions/ext/asgi.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,12 @@ def cancel_on_event(
285285
286286
"""
287287
if cancel_event.is_set():
288-
raise asyncio.CancelledError('Function call was cancelled')
288+
task = asyncio.current_task()
289+
if task is not None:
290+
task.cancel()
291+
raise asyncio.CancelledError(
292+
'Function call was cancelled by client',
293+
)
289294

290295

291296
def build_udf_endpoint(
@@ -314,19 +319,21 @@ def build_udf_endpoint(
314319

315320
async def do_func(
316321
cancel_event: threading.Event,
322+
finished_event: threading.Event,
317323
timer: Timer,
318324
row_ids: Sequence[int],
319325
rows: Sequence[Sequence[Any]],
320326
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
321327
'''Call function on given rows of data.'''
322328
out = []
323-
with timer('call_function'):
329+
async with timer('call_function'):
324330
for row in rows:
325331
cancel_on_event(cancel_event)
326332
if is_async:
327333
out.append(await func(*row))
328334
else:
329335
out.append(func(*row))
336+
finished_event.set()
330337
return row_ids, list(zip(out))
331338

332339
return do_func
@@ -360,6 +367,7 @@ def build_vector_udf_endpoint(
360367

361368
async def do_func(
362369
cancel_event: threading.Event,
370+
finished_event: threading.Event,
363371
timer: Timer,
364372
row_ids: Sequence[int],
365373
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
@@ -371,7 +379,7 @@ async def do_func(
371379
row_ids = array_cls(row_ids)
372380

373381
# Call the function with `cols` as the function parameters
374-
with timer('call_function'):
382+
async with timer('call_function'):
375383
if cols and cols[0]:
376384
if is_async:
377385
out = await func(*[x if m else x[0] for x, m in zip(cols, masks)])
@@ -383,6 +391,7 @@ async def do_func(
383391
else:
384392
out = func()
385393

394+
finished_event.set()
386395
cancel_on_event(cancel_event)
387396

388397
# Single masked value
@@ -425,6 +434,7 @@ def build_tvf_endpoint(
425434

426435
async def do_func(
427436
cancel_event: threading.Event,
437+
finished_event: threading.Event,
428438
timer: Timer,
429439
row_ids: Sequence[int],
430440
rows: Sequence[Sequence[Any]],
@@ -433,7 +443,7 @@ async def do_func(
433443
out_ids: List[int] = []
434444
out = []
435445
# Call function on each row of data
436-
with timer('call_function'):
446+
async with timer('call_function'):
437447
for i, row in zip(row_ids, rows):
438448
cancel_on_event(cancel_event)
439449
if is_async:
@@ -442,6 +452,7 @@ async def do_func(
442452
res = func(*row)
443453
out.extend(as_list_of_tuples(res))
444454
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
455+
finished_event.set()
445456
return out_ids, out
446457

447458
return do_func
@@ -474,6 +485,7 @@ def build_vector_tvf_endpoint(
474485

475486
async def do_func(
476487
cancel_event: threading.Event,
488+
finished_event: threading.Event,
477489
timer: Timer,
478490
row_ids: Sequence[int],
479491
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
@@ -489,7 +501,7 @@ async def do_func(
489501
is_async = asyncio.iscoroutinefunction(func)
490502

491503
# Call function on each column of data
492-
with timer('call_function'):
504+
async with timer('call_function'):
493505
if cols and cols[0]:
494506
if is_async:
495507
func_res = await func(
@@ -505,6 +517,8 @@ async def do_func(
505517
else:
506518
func_res = func()
507519

520+
finished_event.set()
521+
508522
res = get_dataframe_columns(func_res)
509523

510524
cancel_on_event(cancel_event)
@@ -616,15 +630,11 @@ async def cancel_on_disconnect(
616630
)
617631

618632

619-
def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None:
633+
async def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None:
620634
"""Cancel all tasks."""
621635
for task in tasks:
622-
if task.done():
623-
continue
624-
try:
625-
task.cancel()
626-
except Exception:
627-
pass
636+
task.cancel()
637+
await asyncio.gather(*tasks, return_exceptions=True)
628638

629639

630640
def start_counter() -> float:
@@ -1027,17 +1037,24 @@ async def __call__(
10271037
result = []
10281038

10291039
cancel_event = threading.Event()
1040+
finished_event = threading.Event()
1041+
1042+
# Async functions don't need to set the finished event
1043+
if func_info['is_async']:
1044+
finished_event.set()
10301045

10311046
with timer('parse_input'):
10321047
inputs = input_handler['load']( # type: ignore
10331048
func_info['colspec'], b''.join(data),
10341049
)
10351050

10361051
func_task = asyncio.create_task(
1037-
func(cancel_event, timer, *inputs)
1052+
func(cancel_event, finished_event, timer, *inputs)
10381053
if func_info['is_async']
10391054
else to_thread(
1040-
lambda: asyncio.run(func(cancel_event, timer, *inputs)),
1055+
lambda: asyncio.run(
1056+
func(cancel_event, finished_event, timer, *inputs),
1057+
),
10411058
),
10421059
)
10431060
disconnect_task = asyncio.create_task(
@@ -1049,12 +1066,15 @@ async def __call__(
10491066

10501067
all_tasks += [func_task, disconnect_task, timeout_task]
10511068

1052-
with timer('function_wrapper'):
1069+
async with timer('function_wrapper'):
10531070
done, pending = await asyncio.wait(
10541071
all_tasks, return_when=asyncio.FIRST_COMPLETED,
10551072
)
10561073

1057-
cancel_all_tasks(pending)
1074+
await cancel_all_tasks(pending)
1075+
1076+
# Make sure threads finish before we proceed
1077+
finished_event.wait()
10581078

10591079
for task in done:
10601080
if task is disconnect_task:
@@ -1105,7 +1125,7 @@ async def __call__(
11051125
await send(self.error_response_dict)
11061126

11071127
finally:
1108-
cancel_all_tasks(all_tasks)
1128+
await cancel_all_tasks(all_tasks)
11091129

11101130
# Handle api reflection
11111131
elif method == 'GET' and path == self.show_create_function_path:

singlestoredb/functions/ext/timer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
111111
self.metrics.setdefault(timing_info['key'], 0)
112112
self.metrics[timing_info['key']] += elapsed
113113

114+
async def __aenter__(self) -> 'Timer':
115+
"""Async enter for async context manager support."""
116+
return self.__enter__()
117+
118+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
119+
"""Async exit for async context manager support."""
120+
self.__exit__(exc_type, exc_val, exc_tb)
121+
114122
def finish(self) -> None:
115123
"""Finish the current timing context and store the elapsed time."""
116124
if self._stack:

singlestoredb/management/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .cluster import manage_cluster
33
from .files import manage_files
44
from .manager import get_token
5+
from .region import manage_regions
56
from .workspace import get_organization
67
from .workspace import get_secret
78
from .workspace import get_stage

singlestoredb/tests/test_management.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from singlestoredb.management.job import Status
1515
from singlestoredb.management.job import TargetType
1616
from singlestoredb.management.region import Region
17+
from singlestoredb.management.region import RegionManager
1718
from singlestoredb.management.utils import NamedList
1819

1920

@@ -1591,5 +1592,5 @@ def test_no_manager(self):
15911592

15921593
# Verify from_dict class method
15931594
with self.assertRaises(s2.ManagementError) as cm:
1594-
Region.get_shared_tier_regions(None)
1595+
RegionManager.list_shared_tier_regions(None)
15951596
assert 'No workspace manager' in str(cm.exception)

0 commit comments

Comments
 (0)