Skip to content

Commit 7882e8d

Browse files
committed
feat: implement checkpoint response handling for all operations
Implement double-check pattern across all operation types to handle synchronous checkpoint responses, preventing invalid state transitions and unnecessary suspensions.

Bug fix: Callback operations now defer errors to Callback.result() instead of raising immediately in create_callback(), ensuring deterministic replay when code executes between callback creation and result retrieval.

Changes:
- Add OperationExecutor base class with CheckResult for status checking
- Implement double-check pattern: check status before and after checkpoint
- Use is_sync parameter to control checkpoint synchronization behavior
- Refactor all operations to use executor pattern:
  * StepOperationExecutor: sync for AT_MOST_ONCE, async for AT_LEAST_ONCE
  * InvokeOperationExecutor: sync checkpoint, always suspends
  * WaitOperationExecutor: sync checkpoint, suspends if not complete
  * CallbackOperationExecutor: sync checkpoint, defers errors to result()
  * WaitForConditionOperationExecutor: async checkpoint, no second check
  * ChildOperationExecutor: async checkpoint, handles large payloads
- Remove inline while loops, centralize logic in base class
- Update all tests to expect double checkpoint checks with side_effect mocks

Affected modules:
- operation/base.py: New OperationExecutor and CheckResult classes
- operation/step.py: StepOperationExecutor implementation
- operation/invoke.py: InvokeOperationExecutor implementation
- operation/wait.py: WaitOperationExecutor implementation
- operation/callback.py: CallbackOperationExecutor with deferred errors
- operation/wait_for_condition.py: WaitForConditionOperationExecutor
- operation/child.py: ChildOperationExecutor with ReplayChildren support
- All operation tests: Updated mocks for double-check pattern
1 parent 6202b53 commit 7882e8d

File tree

18 files changed

+5153
-938
lines changed

18 files changed

+5153
-938
lines changed

src/aws_durable_execution_sdk_python/context.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,17 @@
2424
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
2525
from aws_durable_execution_sdk_python.logger import Logger, LogInfo
2626
from aws_durable_execution_sdk_python.operation.callback import (
27-
create_callback_handler,
27+
CallbackOperationExecutor,
2828
wait_for_callback_handler,
2929
)
3030
from aws_durable_execution_sdk_python.operation.child import child_handler
31-
from aws_durable_execution_sdk_python.operation.invoke import invoke_handler
31+
from aws_durable_execution_sdk_python.operation.invoke import InvokeOperationExecutor
3232
from aws_durable_execution_sdk_python.operation.map import map_handler
3333
from aws_durable_execution_sdk_python.operation.parallel import parallel_handler
34-
from aws_durable_execution_sdk_python.operation.step import step_handler
35-
from aws_durable_execution_sdk_python.operation.wait import wait_handler
34+
from aws_durable_execution_sdk_python.operation.step import StepOperationExecutor
35+
from aws_durable_execution_sdk_python.operation.wait import WaitOperationExecutor
3636
from aws_durable_execution_sdk_python.operation.wait_for_condition import (
37-
wait_for_condition_handler,
37+
WaitForConditionOperationExecutor,
3838
)
3939
from aws_durable_execution_sdk_python.serdes import (
4040
PassThroughSerDes,
@@ -323,13 +323,14 @@ def create_callback(
323323
if not config:
324324
config = CallbackConfig()
325325
operation_id: str = self._create_step_id()
326-
callback_id: str = create_callback_handler(
326+
executor: CallbackOperationExecutor = CallbackOperationExecutor(
327327
state=self.state,
328328
operation_identifier=OperationIdentifier(
329329
operation_id=operation_id, parent_id=self._parent_id, name=name
330330
),
331331
config=config,
332332
)
333+
callback_id: str = executor.process()
333334
result: Callback = Callback(
334335
callback_id=callback_id,
335336
operation_id=operation_id,
@@ -357,8 +358,10 @@ def invoke(
357358
Returns:
358359
The result of the invoked function
359360
"""
361+
if not config:
362+
config = InvokeConfig[P, R]()
360363
operation_id = self._create_step_id()
361-
result: R = invoke_handler(
364+
executor: InvokeOperationExecutor[R] = InvokeOperationExecutor(
362365
function_name=function_name,
363366
payload=payload,
364367
state=self.state,
@@ -369,6 +372,7 @@ def invoke(
369372
),
370373
config=config,
371374
)
375+
result: R = executor.process()
372376
self.state.track_replay(operation_id=operation_id)
373377
return result
374378

@@ -505,8 +509,10 @@ def step(
505509
) -> T:
506510
step_name = self._resolve_step_name(name, func)
507511
logger.debug("Step name: %s", step_name)
512+
if not config:
513+
config = StepConfig()
508514
operation_id = self._create_step_id()
509-
result: T = step_handler(
515+
executor: StepOperationExecutor[T] = StepOperationExecutor(
510516
func=func,
511517
config=config,
512518
state=self.state,
@@ -517,6 +523,7 @@ def step(
517523
),
518524
context_logger=self.logger,
519525
)
526+
result: T = executor.process()
520527
self.state.track_replay(operation_id=operation_id)
521528
return result
522529

@@ -532,15 +539,17 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
532539
msg = "duration must be at least 1 second"
533540
raise ValidationError(msg)
534541
operation_id = self._create_step_id()
535-
wait_handler(
536-
seconds=seconds,
542+
wait_seconds = duration.seconds
543+
executor: WaitOperationExecutor = WaitOperationExecutor(
544+
seconds=wait_seconds,
537545
state=self.state,
538546
operation_identifier=OperationIdentifier(
539547
operation_id=operation_id,
540548
parent_id=self._parent_id,
541549
name=name,
542550
),
543551
)
552+
executor.process()
544553
self.state.track_replay(operation_id=operation_id)
545554

546555
def wait_for_callback(
@@ -584,17 +593,20 @@ def wait_for_condition(
584593
raise ValidationError(msg)
585594

586595
operation_id = self._create_step_id()
587-
result: T = wait_for_condition_handler(
588-
check=check,
589-
config=config,
590-
state=self.state,
591-
operation_identifier=OperationIdentifier(
592-
operation_id=operation_id,
593-
parent_id=self._parent_id,
594-
name=name,
595-
),
596-
context_logger=self.logger,
596+
executor: WaitForConditionOperationExecutor[T] = (
597+
WaitForConditionOperationExecutor(
598+
check=check,
599+
config=config,
600+
state=self.state,
601+
operation_identifier=OperationIdentifier(
602+
operation_id=operation_id,
603+
parent_id=self._parent_id,
604+
name=name,
605+
),
606+
context_logger=self.logger,
607+
)
597608
)
609+
result: T = executor.process()
598610
self.state.track_replay(operation_id=operation_id)
599611
return result
600612

src/aws_durable_execution_sdk_python/operation/base.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""Base classes for operation executors with checkpoint response handling."""
2+
3+
from __future__ import annotations
4+
5+
from abc import ABC, abstractmethod
6+
from dataclasses import dataclass
7+
from typing import TYPE_CHECKING, Generic, TypeVar
8+
9+
from aws_durable_execution_sdk_python.exceptions import InvalidStateError
10+
11+
if TYPE_CHECKING:
12+
from aws_durable_execution_sdk_python.state import CheckpointedResult
13+
14+
T = TypeVar("T")


@dataclass(frozen=True)
class CheckResult(Generic[T]):
    """Result of checking operation checkpoint status.

    Encapsulates the outcome of checking an operation's status and determines
    the next action in the operation execution flow.

    IMPORTANT: Do not construct directly. Use factory methods:
    - create_is_ready_to_execute(checkpoint) - operation ready to execute
    - create_started() - checkpoint created, check status again
    - create_completed(result) - terminal result available

    The factories produce exactly three flag combinations:
    - is_ready_to_execute=True,  has_checkpointed_result=False -> execute
    - is_ready_to_execute=False, has_checkpointed_result=True  -> return result
    - is_ready_to_execute=False, has_checkpointed_result=False -> check again

    Attributes:
        is_ready_to_execute: True if the operation is ready to execute its logic
        has_checkpointed_result: True if a terminal result is already available
        checkpointed_result: Checkpoint data for execute()
        deserialized_result: Final result when operation completed
    """

    is_ready_to_execute: bool
    has_checkpointed_result: bool
    checkpointed_result: CheckpointedResult | None = None
    # May legitimately be None for a completed operation that returns None,
    # so consumers must branch on has_checkpointed_result, not on this value.
    deserialized_result: T | None = None

    @classmethod
    def create_is_ready_to_execute(
        cls, checkpoint: CheckpointedResult
    ) -> CheckResult[T]:
        """Create a CheckResult indicating the operation is ready to execute.

        Args:
            checkpoint: The checkpoint data to pass to execute()

        Returns:
            CheckResult with is_ready_to_execute=True
        """
        return cls(
            is_ready_to_execute=True,
            has_checkpointed_result=False,
            checkpointed_result=checkpoint,
        )

    @classmethod
    def create_started(cls) -> CheckResult[T]:
        """Create a CheckResult signaling that a checkpoint was created.

        Signals that process() should verify checkpoint status again to detect
        if the operation completed already during checkpoint creation.

        Returns:
            CheckResult indicating process() should check status again
        """
        return cls(is_ready_to_execute=False, has_checkpointed_result=False)

    @classmethod
    def create_completed(cls, result: T) -> CheckResult[T]:
        """Create a CheckResult with a terminal result already deserialized.

        Args:
            result: The final deserialized result

        Returns:
            CheckResult with has_checkpointed_result=True and deserialized_result set
        """
        return cls(
            is_ready_to_execute=False,
            has_checkpointed_result=True,
            deserialized_result=result,
        )
86+
87+
88+
class OperationExecutor(ABC, Generic[T]):
    """Base class for durable operations with checkpoint response handling.

    Provides a framework for implementing operations that check status after
    creating START checkpoints to handle synchronous completion, avoiding
    unnecessary execution or suspension.

    The common pattern:
    1. Check operation status
    2. Create START checkpoint if needed
    3. Check status again (detects synchronous completion)
    4. Execute operation logic when ready

    Subclasses must implement:
    - check_result_status(): Check status, create checkpoint if needed, return next action
    - execute(): Execute the operation logic with checkpoint data
    """

    @abstractmethod
    def check_result_status(self) -> CheckResult[T]:
        """Check operation status and create START checkpoint if needed.

        process() always calls this once, and calls it a second time only when
        the first call returned a CheckResult with neither flag set (i.e. a
        checkpoint was just created), to detect if the operation completed
        immediately.

        This method should:
        1. Get the current checkpoint result
        2. Check for terminal statuses (SUCCEEDED, FAILED, etc.) and handle them
        3. Check for pending statuses and suspend if needed
        4. Create a START checkpoint if the operation hasn't started
        5. Return a CheckResult indicating the next action

        Returns:
            CheckResult indicating whether to:
            - Return a terminal result (has_checkpointed_result=True)
            - Execute operation logic (is_ready_to_execute=True)
            - Check status again (neither flag set - checkpoint was just created)

        Raises:
            Operation-specific exceptions for terminal failure states
            SuspendExecution for pending states
        """
        ...  # pragma: no cover

    @abstractmethod
    def execute(self, checkpointed_result: CheckpointedResult) -> T:
        """Execute operation logic with checkpoint data.

        This method is called when the operation is ready to execute its core logic.
        It receives the checkpoint data that was returned by check_result_status().

        Args:
            checkpointed_result: The checkpoint data containing operation state

        Returns:
            The result of executing the operation

        Raises:
            May raise operation-specific errors during execution
        """
        ...  # pragma: no cover

    def process(self) -> T:
        """Process operation with checkpoint response handling.

        Orchestrates the double-check pattern:
        1. Check status (handles replay and existing checkpoints)
        2. If checkpoint was just created, check status again (detects synchronous completion)
        3. Return terminal result if available
        4. Execute operation logic if ready
        5. Raise error for invalid states

        Returns:
            The final result of the operation

        Raises:
            InvalidStateError: If the check result is in an invalid state
            May raise operation-specific errors from check_result_status() or execute()
        """
        # Check 1: Entry (handles replay and existing checkpoints)
        result = self.check_result_status()

        # If checkpoint was created, verify checkpoint response for immediate status change.
        # This is a single re-check, not a loop: a second "neither flag" result
        # falls through to the InvalidStateError at the end.
        if not result.is_ready_to_execute and not result.has_checkpointed_result:
            result = self.check_result_status()

        # Return terminal result if available (can be None for operations that return None)
        if result.has_checkpointed_result:
            return result.deserialized_result  # type: ignore[return-value]

        # Execute operation logic
        if result.is_ready_to_execute:
            if result.checkpointed_result is None:
                msg = "CheckResult is marked ready to execute but checkpointed result is not set."
                raise InvalidStateError(msg)
            return self.execute(result.checkpointed_result)

        # Invalid state - neither terminal nor ready to execute
        msg = "Invalid CheckResult state: neither terminal nor ready to execute"
        raise InvalidStateError(msg)

0 commit comments

Comments
 (0)