diff --git a/run.go b/run.go index 9c51aaf8d..665d9a897 100644 --- a/run.go +++ b/run.go @@ -134,6 +134,11 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +type RunRequestStreaming struct { + RunRequest + Stream bool `json:"stream"` +} + // ReponseFormat specifies the format the model must output. // https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. // Type can either be text or json_object. @@ -166,6 +171,11 @@ type CreateThreadAndRunRequest struct { Thread ThreadRequest `json:"thread"` } +type CreateThreadAndStreamRequest struct { + CreateThreadAndRunRequest + Stream bool `json:"stream"` +} + type RunStep struct { ID string `json:"id"` Object string `json:"object"` @@ -354,6 +364,43 @@ func (c *Client) SubmitToolOutputs( return } +type SubmitToolOutputsStreamRequest struct { + SubmitToolOutputsRequest + Stream bool `json:"stream"` +} + +func (c *Client) SubmitToolOutputsStream( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest, +) (stream *AssistantStream, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + r := SubmitToolOutputsStreamRequest{ + SubmitToolOutputsRequest: request, + Stream: true, + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + // CancelRun cancels a run. func (c *Client) CancelRun( ctx context.Context, @@ -392,6 +439,114 @@ func (c *Client) CreateThreadAndRun( return } +type StreamMessageDelta struct { + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIDs []string `json:"file_ids"` +} + +type AssistantStreamEvent struct { + ID string `json:"id"` + Object string `json:"object"` + Delta StreamMessageDelta `json:"delta,omitempty"` + + // Run + CreatedAt int64 `json:"created_at,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + AssistantID string `json:"assistant_id,omitempty"` + Status RunStatus `json:"status,omitempty"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Usage Usage `json:"usage,omitempty"` + + // ThreadMessage.Completed + Role string `json:"role,omitempty"` + Content []MessageContent `json:"content,omitempty"` + // IncompleteDetails + // IncompleteAt + + // Run steps + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + StepDetails StepDetails `json:"step_details"` + ExpiredAt *int64 `json:"expired_at,omitempty"` +} + +type AssistantStream struct { + *streamReader[AssistantStreamEvent] +} + +func (c *Client) CreateThreadAndStream( + ctx context.Context, + request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { + urlSuffix := "/threads/runs" + sr := CreateThreadAndStreamRequest{ + CreateThreadAndRunRequest: request, + Stream: true, + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(sr), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + +func (c *Client) CreateRunStreaming( + ctx context.Context, + threadID string, + request RunRequest) (stream *AssistantStream, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + + r := RunRequestStreaming{ + RunRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + // RetrieveRunStep retrieves a run step. func (c *Client) RetrieveRunStep( ctx context.Context, diff --git a/run_test.go b/run_test.go index cdf99db05..f3445852e 100644 --- a/run_test.go +++ b/run_test.go @@ -219,6 +219,31 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndRun error") + _, err = client.CreateThreadAndStream(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndStream error") + + _, err = client.CreateRunStreaming(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRunStreaming error") + + _, err = client.SubmitToolOutputsStream(ctx, threadID, runID, openai.SubmitToolOutputsRequest{ + ToolOutputs: nil, + }) + checks.NoError(t, err, "SubmitToolOutputsStream error") + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) checks.NoError(t, err, "RetrieveRunStep error") diff --git a/stream_reader.go b/stream_reader.go index 6faefe0a7..792806b65 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -17,7 +17,7 @@ var ( ) type streamable interface { - ChatCompletionStreamResponse | CompletionResponse + ChatCompletionStreamResponse | CompletionResponse | AssistantStreamEvent } type streamReader[T streamable] struct {