Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type systemDatabase interface {
insertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error)
listWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error)
updateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error
awaitWorkflowResult(ctx context.Context, workflowID string) (*string, error)
awaitWorkflowResult(ctx context.Context, workflowID string, pollInterval time.Duration) (*string, error)
cancelWorkflow(ctx context.Context, workflowID string) error
cancelAllBefore(ctx context.Context, cutoffTime time.Time) error
resumeWorkflow(ctx context.Context, workflowID string) error
Expand Down Expand Up @@ -1188,9 +1188,12 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st
return forkedWorkflowID, nil
}

func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (*string, error) {
func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string, pollInterval time.Duration) (*string, error) {
query := fmt.Sprintf(`SELECT status, output, error FROM %s.workflow_status WHERE workflow_uuid = $1`, pgx.Identifier{s.schema}.Sanitize())
var status WorkflowStatusType
if pollInterval <= 0 {
pollInterval = _DB_RETRY_INTERVAL
}
for {
select {
case <-ctx.Done():
Expand All @@ -1204,7 +1207,7 @@ func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (*st
err := row.Scan(&status, &outputString, &errorStr)
if err != nil {
if err == pgx.ErrNoRows {
time.Sleep(_DB_RETRY_INTERVAL)
time.Sleep(pollInterval)
continue
}
return nil, fmt.Errorf("failed to query workflow status: %w", err)
Expand All @@ -1219,7 +1222,7 @@ func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (*st
case WorkflowStatusCancelled:
return outputString, newAwaitedWorkflowCancelledError(workflowID)
default:
time.Sleep(_DB_RETRY_INTERVAL)
time.Sleep(pollInterval)
}
}
}
Expand Down
25 changes: 20 additions & 5 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,12 @@ type GetResultOption func(*getResultOptions)

// getResultOptions holds the configuration for GetResult execution.
type getResultOptions struct {
timeout time.Duration
timeout time.Duration
pollInterval time.Duration
}

func defaultGetResultOptions() *getResultOptions {
return &getResultOptions{pollInterval: _DB_RETRY_INTERVAL}
}

// WithHandleTimeout sets a timeout for the GetResult operation.
Expand All @@ -115,6 +120,16 @@ func WithHandleTimeout(timeout time.Duration) GetResultOption {
}
}

// WithHandlePollingInterval sets the polling interval for awaiting workflow completion in GetResult.
// If a non-positive interval is provided, the default interval is used.
func WithHandlePollingInterval(interval time.Duration) GetResultOption {
return func(opts *getResultOptions) {
if interval > 0 {
opts.pollInterval = interval
}
}
}

// GetStatus returns the current status of the workflow from the database
// If the DBOSContext is running in client mode, do not load input and outputs
func (h *baseWorkflowHandle) GetStatus() (WorkflowStatus, error) {
Expand Down Expand Up @@ -186,7 +201,7 @@ type workflowHandle[R any] struct {
}

func (h *workflowHandle[R]) GetResult(opts ...GetResultOption) (R, error) {
options := &getResultOptions{}
options := defaultGetResultOptions()
for _, opt := range opts {
opt(options)
}
Expand Down Expand Up @@ -253,7 +268,7 @@ type workflowPollingHandle[R any] struct {
}

func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error) {
options := &getResultOptions{}
options := defaultGetResultOptions()
for _, opt := range opts {
opt(options)
}
Expand All @@ -269,7 +284,7 @@ func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error)
}

encodedResult, err := retryWithResult(ctx, func() (any, error) {
return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(ctx, h.workflowID)
return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(ctx, h.workflowID, options.pollInterval)
}, withRetrierLogger(h.dbosContext.(*dbosContext).logger))

completedTime := time.Now()
Expand Down Expand Up @@ -1051,7 +1066,7 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
c.logger.Warn("Workflow ID conflict detected. Waiting for existing workflow to complete", "workflow_id", workflowID)
var encodedResult any
encodedResult, err = retryWithResult(c, func() (any, error) {
return c.systemDB.awaitWorkflowResult(uncancellableCtx, workflowID)
return c.systemDB.awaitWorkflowResult(uncancellableCtx, workflowID, _DB_RETRY_INTERVAL)
}, withRetrierLogger(c.logger))
// Keep the encoded result - decoding will happen in RunWorkflow[P,R] when we know the target type
outcomeChan <- workflowOutcome[any]{result: encodedResult, err: err, needsDecoding: true}
Expand Down
6 changes: 3 additions & 3 deletions dbos/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4543,7 +4543,7 @@ func TestWorkflowIdentity(t *testing.T) {
})
}

func TestWorkflowHandleTimeout(t *testing.T) {
func TestWorkflowHandles(t *testing.T) {
dbosCtx := setupDBOS(t, true, true)
RegisterWorkflow(dbosCtx, slowWorkflow)

Expand All @@ -4552,7 +4552,7 @@ func TestWorkflowHandleTimeout(t *testing.T) {
require.NoError(t, err, "failed to start workflow")

start := time.Now()
_, err = handle.GetResult(WithHandleTimeout(10 * time.Millisecond))
_, err = handle.GetResult(WithHandleTimeout(10*time.Millisecond), WithHandlePollingInterval(1*time.Millisecond))
duration := time.Since(start)

require.Error(t, err, "expected timeout error")
Expand All @@ -4573,7 +4573,7 @@ func TestWorkflowHandleTimeout(t *testing.T) {
_, ok := pollingHandle.(*workflowPollingHandle[string])
require.True(t, ok, "expected polling handle, got %T", pollingHandle)

_, err = pollingHandle.GetResult(WithHandleTimeout(10 * time.Millisecond))
_, err = pollingHandle.GetResult(WithHandleTimeout(10*time.Millisecond), WithHandlePollingInterval(1*time.Millisecond))

require.Error(t, err, "expected timeout error")
assert.True(t, errors.Is(err, context.DeadlineExceeded),
Expand Down