def intercept_nexus_operation(
    self, next: temporalio.worker.NexusOperationInboundInterceptor
) -> temporalio.worker.NexusOperationInboundInterceptor:
    """Wrap the Nexus operation inbound chain with a tracing interceptor.

    Implementation of
    :py:meth:`temporalio.worker.Interceptor.intercept_nexus_operation`.

    Args:
        next: The inbound interceptor the returned interceptor delegates to.

    Returns:
        A tracing inbound interceptor constructed from ``next`` and this
        root interceptor.
    """
    tracing_inbound = _TracingNexusOperationInboundInterceptor(next, self)
    return tracing_inbound
class _TracingNexusOperationInboundInterceptor(
    temporalio.worker.NexusOperationInboundInterceptor
):
    """Inbound Nexus interceptor that runs start/cancel handlers inside a span.

    The parent context for each span is extracted from the incoming Nexus
    headers using the root interceptor's text map propagator.
    """

    def __init__(
        self,
        next: temporalio.worker.NexusOperationInboundInterceptor,
        root: TracingInterceptor,
    ) -> None:
        """Store the next interceptor in the chain and the root tracing interceptor."""
        super().__init__(next)
        self._root = root

    def _context_from_nexus_headers(self, headers: Mapping[str, str]):
        """Extract a context from string headers via the root's propagator."""
        propagator = self._root.text_map_propagator
        return propagator.extract(headers)

    async def execute_nexus_operation_start(
        self, input: temporalio.worker.ExecuteNexusOperationStartInput
    ) -> (
        nexusrpc.handler.StartOperationResultSync[Any]
        | nexusrpc.handler.StartOperationResultAsync
    ):
        """Run the next start handler inside a SERVER span."""
        op_ctx = input.ctx
        span_name = (
            f"RunStartNexusOperationHandler:{op_ctx.service}/{op_ctx.operation}"
        )
        parent_context = self._context_from_nexus_headers(op_ctx.headers)
        with self._root._start_as_current_span(
            span_name,
            context=parent_context,
            attributes={},
            input_with_ctx=input,
            kind=opentelemetry.trace.SpanKind.SERVER,
        ):
            return await self.next.execute_nexus_operation_start(input)

    async def execute_nexus_operation_cancel(
        self, input: temporalio.worker.ExecuteNexusOperationCancelInput
    ) -> None:
        """Run the next cancel handler inside a SERVER span."""
        op_ctx = input.ctx
        span_name = (
            f"RunCancelNexusOperationHandler:{op_ctx.service}/{op_ctx.operation}"
        )
        parent_context = self._context_from_nexus_headers(op_ctx.headers)
        with self._root._start_as_current_span(
            span_name,
            context=parent_context,
            attributes={},
            input_with_ctx=input,
            kind=opentelemetry.trace.SpanKind.SERVER,
        ):
            return await self.next.execute_nexus_operation_cancel(input)
new_context_carrier: _CarrierDict = {} self.text_map_propagator.inject(new_context_carrier) + # Invoke info = temporalio.workflow.info() attributes: dict[str, opentelemetry.util.types.AttributeValue] = { "temporalWorkflowID": info.workflow_id, "temporalRunID": info.run_id, } + if additional_attributes: attributes.update(additional_attributes) updated_context_carrier = self._extern_functions[ @@ -640,10 +716,16 @@ def _completed_span( ) # Add to outbound if needed - if add_to_outbound and updated_context_carrier: - add_to_outbound.headers = self._context_carrier_to_headers( - updated_context_carrier, add_to_outbound.headers - ) + if updated_context_carrier: + if add_to_outbound: + add_to_outbound.headers = self._context_carrier_to_headers( + updated_context_carrier, add_to_outbound.headers + ) + + if add_to_outbound_str: + add_to_outbound_str.headers = _carrier_to_nexus_headers( + updated_context_carrier, add_to_outbound_str.headers + ) def _set_on_context( self, context: opentelemetry.context.Context @@ -722,6 +804,29 @@ def start_local_activity( ) return super().start_local_activity(input) + async def start_nexus_operation( + self, input: temporalio.worker.StartNexusOperationInput[Any, Any] + ) -> temporalio.workflow.NexusOperationHandle[Any]: + self.root._completed_span( + f"StartNexusOperation:{input.service}/{input.operation_name}", + kind=opentelemetry.trace.SpanKind.CLIENT, + add_to_outbound_str=input, + ) + + return await super().start_nexus_operation(input) + + +def _carrier_to_nexus_headers( + carrier: _CarrierDict, initial: Mapping[str, str] | None = None +) -> Mapping[str, str]: + out = {**initial} if initial else {} + for k, v in carrier.items(): + if isinstance(v, list): + out[k] = ",".join(v) + else: + out[k] = v + return out + class workflow: """Contains static methods that are safe to call from within a workflow. 
diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 1d7b2558e..8388be24d 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -6,11 +6,14 @@ ActivityOutboundInterceptor, ContinueAsNewInput, ExecuteActivityInput, + ExecuteNexusOperationCancelInput, + ExecuteNexusOperationStartInput, ExecuteWorkflowInput, HandleQueryInput, HandleSignalInput, HandleUpdateInput, Interceptor, + NexusOperationInboundInterceptor, SignalChildWorkflowInput, SignalExternalWorkflowInput, StartActivityInput, @@ -80,6 +83,7 @@ "ActivityOutboundInterceptor", "WorkflowInboundInterceptor", "WorkflowOutboundInterceptor", + "NexusOperationInboundInterceptor", "Plugin", # Interceptor input "ContinueAsNewInput", @@ -95,6 +99,8 @@ "StartLocalActivityInput", "StartNexusOperationInput", "WorkflowInterceptorClassInput", + "ExecuteNexusOperationStartInput", + "ExecuteNexusOperationCancelInput", # Advanced activity classes "SharedStateManager", "SharedHeartbeatSender", diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 338fa9286..d3b838679 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -66,6 +66,20 @@ def workflow_interceptor_class( """ return None + def intercept_nexus_operation( + self, next: NexusOperationInboundInterceptor + ) -> NexusOperationInboundInterceptor: + """Method called for intercepting a Nexus operation. + + Args: + next: The underlying inbound this interceptor + should delegate to. + + Returns: + The new interceptor that should be used for the Nexus operation. 
# Not frozen: the OpenTelemetry tracing interceptor reassigns ``ctx`` on this
# input to inject propagation headers before delegating.
@dataclass
class ExecuteNexusOperationStartInput:
    """Input for :py:meth:`NexusOperationInboundInterceptor.execute_nexus_operation_start`."""

    # Start-operation context; interceptors read ``service``, ``operation`` and
    # ``headers`` from it.
    ctx: nexusrpc.handler.StartOperationContext
    # Value forwarded to the operation handler's ``start`` call.
    input: Any


# Not frozen: the OpenTelemetry tracing interceptor reassigns ``ctx`` on this
# input to inject propagation headers before delegating.
@dataclass
class ExecuteNexusOperationCancelInput:
    """Input for :py:meth:`NexusOperationInboundInterceptor.execute_nexus_operation_cancel`."""

    # Cancel-operation context; interceptors read ``service``, ``operation`` and
    # ``headers`` from it.
    ctx: nexusrpc.handler.CancelOperationContext
    # Token forwarded to the operation handler's ``cancel`` call.
    token: str
+ """ + self.next = next + + async def execute_nexus_operation_start( + self, input: ExecuteNexusOperationStartInput + ) -> ( + nexusrpc.handler.StartOperationResultSync[Any] + | nexusrpc.handler.StartOperationResultAsync + ): + """Called to start a Nexus operation""" + return await self.next.execute_nexus_operation_start(input) + + async def execute_nexus_operation_cancel( + self, input: ExecuteNexusOperationCancelInput + ) -> None: + """Called to cancel an in progress Nexus operation""" + return await self.next.execute_nexus_operation_cancel(input) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 281939398..bc06fbd4a 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -9,11 +9,13 @@ import threading from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass +from functools import reduce from typing import ( Any, NoReturn, ParamSpec, TypeVar, + cast, ) import google.protobuf.json_format @@ -38,7 +40,12 @@ from temporalio.nexus import Info, logger from temporalio.service import RPCError, RPCStatusCode -from ._interceptor import Interceptor +from ._interceptor import ( + ExecuteNexusOperationCancelInput, + ExecuteNexusOperationStartInput, + Interceptor, + NexusOperationInboundInterceptor, +) _TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure" @@ -73,12 +80,15 @@ def __init__( self._task_queue = task_queue self._metric_meter = metric_meter + middleware = _NexusMiddlewareForInterceptors(interceptors) # If an executor is provided, we wrap the executor with one that will # copy the contextvars.Context to the thread on submit handler_executor = _ContextPropagatingExecutor(executor) if executor else None + self._handler = Handler( + service_handlers, handler_executor, middleware=[middleware] + ) - self._handler = Handler(service_handlers, handler_executor) self._data_converter = data_converter # TODO(nexus-preview): interceptors self._interceptors = interceptors @@ -605,6 
class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware):
    """Operation-handler middleware that runs the worker's Nexus interceptor chain."""

    def __init__(self, interceptors: Sequence[Interceptor]) -> None:
        """Remember the worker interceptors to build chains from."""
        self._interceptors = interceptors

    def intercept(
        self,
        ctx: nexusrpc.handler.OperationContext,
        next: nexusrpc.handler.MiddlewareSafeOperationHandler,
    ) -> nexusrpc.handler.MiddlewareSafeOperationHandler:
        """Wrap ``next`` so start/cancel calls pass through every interceptor.

        Args:
            ctx: Operation context supplied by the middleware caller (unused here).
            next: The handler that ultimately executes the operation.

        Returns:
            An operation handler whose calls enter the interceptor chain.
        """
        # Innermost link: forwards interceptor calls to the real handler.
        chain: NexusOperationInboundInterceptor = _NexusOperationInboundInterceptorImpl(
            next
        )
        # Wrap from last to first so the first interceptor in the sequence
        # ends up outermost.
        for interceptor in reversed(self._interceptors):
            chain = interceptor.intercept_nexus_operation(chain)

        return _NexusOperationHandlerForInterceptor(chain)
@nexusrpc.handler.service_handler
class InterceptedNexusService:
    """Nexus service used by the tracing tests; its operation is backed by a workflow."""

    @nexus.workflow_run_operation
    async def intercepted_operation(
        self, ctx: nexus.WorkflowRunOperationContext, input: str
    ) -> nexus.WorkflowHandle[None]:
        """Start ``ExpectCancelNexusWorkflow`` with a unique, request-scoped ID."""
        backing_workflow_id = f"wf-{uuid.uuid4()}-{ctx.request_id}"
        handle = await ctx.start_workflow(
            ExpectCancelNexusWorkflow.run,
            input,
            id=backing_workflow_id,
        )
        return handle
interceptor + client_config = client.config() + client_config["interceptors"] = [TracingInterceptor(tracer)] + client = Client(**client_config) + + task_queue = f"task-queue-{uuid.uuid4()}" + await create_nexus_endpoint(task_queue, client) + async with Worker( + client, + task_queue=task_queue, + workflows=[TracingWorkflow, ExpectCancelNexusWorkflow], + activities=[tracing_activity], + nexus_service_handlers=[InterceptedNexusService()], + # Needed so we can wait to send update at the right time + workflow_runner=UnsandboxedWorkflowRunner(), + ): + # Run workflow with various actions + workflow_id = f"workflow_{uuid.uuid4()}" + workflow_params = TracingWorkflowParam( + actions=[ + TracingWorkflowAction(start_and_cancel_nexus_operation=True), + ] + ) + handle = await client.start_workflow( + TracingWorkflow.run, + workflow_params, + id=workflow_id, + task_queue=task_queue, + ) + await handle.result() + + # Dump debug with attributes, but do string assertion test without + logging.debug( + "Spans:\n%s", + "\n".join(dump_spans(exporter.get_finished_spans(), with_attributes=False)), + ) + assert dump_spans(exporter.get_finished_spans(), with_attributes=False) == [ + "StartWorkflow:TracingWorkflow", + " RunWorkflow:TracingWorkflow", + " MyCustomSpan", + " StartNexusOperation:InterceptedNexusService/intercepted_operation", + " RunStartNexusOperationHandler:InterceptedNexusService/intercepted_operation", + " StartWorkflow:ExpectCancelNexusWorkflow", + " RunWorkflow:ExpectCancelNexusWorkflow", + " RunCancelNexusOperationHandler:InterceptedNexusService/intercepted_operation", + " CompleteWorkflow:TracingWorkflow", + ] + + def dump_spans( spans: Iterable[ReadableSpan], *, diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index 7746dce2d..a5616bb17 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -4,9 +4,10 @@ from datetime import timedelta from typing import Any, List, NoReturn, Optional, Tuple, Type 
class TracingNexusInboundInterceptor(NexusOperationInboundInterceptor):
    """Test interceptor that records each Nexus start/cancel call before delegating."""

    async def execute_nexus_operation_start(
        self, input: ExecuteNexusOperationStartInput
    ) -> (
        nexusrpc.handler.StartOperationResultSync[Any]
        | nexusrpc.handler.StartOperationResultAsync
    ):
        """Append a start trace entry, then delegate to the next interceptor."""
        trace_name = (
            f"nexus.start_operation.{input.ctx.service}.{input.ctx.operation}"
        )
        interceptor_traces.append((trace_name, input))
        return await super().execute_nexus_operation_start(input)

    async def execute_nexus_operation_cancel(
        self, input: ExecuteNexusOperationCancelInput
    ) -> None:
        """Append a cancel trace entry, then delegate to the next interceptor."""
        trace_name = (
            f"nexus.cancel_operation.{input.ctx.service}.{input.ctx.operation}"
        )
        interceptor_traces.append((trace_name, input))
        return await super().execute_nexus_operation_cancel(input)
pass await self.finish.wait() @@ -230,9 +281,10 @@ async def test_worker_interceptor(client: Client, env: WorkflowEnvironment): async with Worker( client, task_queue=task_queue, - workflows=[InterceptedWorkflow], + workflows=[InterceptedWorkflow, ExpectCancelNexusWorkflow], activities=[intercepted_activity], interceptors=[TracingWorkerInterceptor()], + nexus_service_handlers=[InterceptedNexusService()], ): # Run workflow handle = await client.start_workflow( @@ -311,6 +363,14 @@ def pop_trace(name: str, filter: Callable[[Any], bool] | None = None) -> Any: assert pop_trace( "workflow.update.validator", lambda v: v.args[0] == "reject-me" ) + assert pop_trace( + "nexus.start_operation.InterceptedNexusService.intercepted_operation", + lambda v: v.input == "nexus-workflow", + ) + assert pop_trace("workflow.execute", lambda v: v.args[0] == "nexus-workflow") + assert pop_trace( + "nexus.cancel_operation.InterceptedNexusService.intercepted_operation", + ) # Confirm no unexpected traces assert not interceptor_traces diff --git a/uv.lock b/uv.lock index 7d8113d22..600e3e8fc 100644 --- a/uv.lock +++ b/uv.lock @@ -1760,14 +1760,14 @@ wheels = [ [[package]] name = "nexus-rpc" -version = "1.2.0" +version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/50/95d7bc91f900da5e22662c82d9bf0f72a4b01f2a552708bf2f43807707a1/nexus_rpc-1.2.0.tar.gz", hash = "sha256:b4ddaffa4d3996aaeadf49b80dfcdfbca48fe4cb616defaf3b3c5c2c8fc61890", size = 74142, upload-time = "2025-11-17T19:17:06.798Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/f2/d54f5c03d8f4672ccc0875787a385f53dcb61f98a8ae594b5620e85b9cb3/nexus_rpc-1.3.0.tar.gz", hash = "sha256:e56d3b57b60d707ce7a72f83f23f106b86eca1043aa658e44582ab5ff30ab9ad", size = 75650, upload-time = "2025-12-08T22:59:13.002Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/13/04/eaac430d0e6bf21265ae989427d37e94be5e41dc216879f1fbb6c5339942/nexus_rpc-1.2.0-py3-none-any.whl", hash = "sha256:977876f3af811ad1a09b2961d3d1ac9233bda43ff0febbb0c9906483b9d9f8a3", size = 28166, upload-time = "2025-11-17T19:17:05.64Z" }, + { url = "https://files.pythonhosted.org/packages/d6/74/0afd841de3199c148146c1d43b4bfb5605b2f1dc4c9a9087fe395091ea5a/nexus_rpc-1.3.0-py3-none-any.whl", hash = "sha256:aee0707b4861b22d8124ecb3f27d62dafbe8777dc50c66c91e49c006f971b92d", size = 28873, upload-time = "2025-12-08T22:59:12.024Z" }, ] [[package]] @@ -3021,7 +3021,7 @@ dev = [ requires-dist = [ { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, - { name = "nexus-rpc", specifier = "==1.2.0" }, + { name = "nexus-rpc", specifier = "==1.3.0" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.3,<0.5" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" },