From 8baadc813e9c210659f7a83f615616f80034dae8 Mon Sep 17 00:00:00 2001 From: coolbaluk Date: Mon, 29 Apr 2024 15:04:59 +0100 Subject: [PATCH 1/4] Add assistants stream --- run.go | 150 +++++++++++++++++++++++++++++++++++++++++++++++ run_test.go | 15 +++++ stream_reader.go | 2 +- 3 files changed, 166 insertions(+), 1 deletion(-) diff --git a/run.go b/run.go index 094b0a4db..687470cb8 100644 --- a/run.go +++ b/run.go @@ -124,6 +124,11 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +type RunRequestStreaming struct { + RunRequest + Stream bool `json:"stream"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } @@ -149,6 +154,11 @@ type CreateThreadAndRunRequest struct { Thread ThreadRequest `json:"thread"` } +type CreateThreadAndStreamRequest struct { + CreateThreadAndRunRequest + Stream bool `json:"stream"` +} + type RunStep struct { ID string `json:"id"` Object string `json:"object"` @@ -337,6 +347,43 @@ func (c *Client) SubmitToolOutputs( return } +type SubmitToolOutputsStreamRequest struct { + SubmitToolOutputsRequest + Stream bool `json:"stream"` +} + +func (c *Client) SubmitToolOutputsStream( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest, +) (stream *AssistantStream, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + r := SubmitToolOutputsStreamRequest{ + SubmitToolOutputsRequest: request, + Stream: true, + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + // CancelRun cancels a run. func (c *Client) CancelRun( ctx context.Context, @@ -375,6 +422,109 @@ func (c *Client) CreateThreadAndRun( return } +type StreamMessageDelta struct { + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIDs []string `json:"file_ids"` +} + +type AssistantStreamEvent struct { + ID string `json:"id"` + Object string `json:"object"` + Delta StreamMessageDelta `json:"delta,omitempty"` + + // Run + CreatedAt int64 `json:"created_at,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + AssistantID string `json:"assistant_id,omitempty"` + Status RunStatus `json:"status,omitempty"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Usage Usage `json:"usage,omitempty"` + + // ThreadMessage.Completed + Role string `json:"role,omitempty"` + Content []MessageContent `json:"content,omitempty"` + // IncompleteDetails + // IncompleteAt + + // Run steps + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + StepDetails StepDetails `json:"step_details"` + ExpiredAt *int64 `json:"expired_at,omitempty"` +} + +type AssistantStream struct { + *streamReader[AssistantStreamEvent] +} + +func (c *Client) CreateThreadAndStream(ctx context.Context, request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { + urlSuffix := "/threads/runs" + sr := CreateThreadAndStreamRequest{ + CreateThreadAndRunRequest: request, + Stream: true, + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(sr), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + +func (c *Client) CreateRunStreaming(ctx context.Context, threadID string, request RunRequest) (stream *AssistantStream, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + + r := RunRequestStreaming{ + RunRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + // RetrieveRunStep retrieves a run step. func (c *Client) RetrieveRunStep( ctx context.Context, diff --git a/run_test.go b/run_test.go index cdf99db05..f06c1564b 100644 --- a/run_test.go +++ b/run_test.go @@ -219,6 +219,21 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndRun error") + _, err = client.CreateThreadAndStream(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndStream error") + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) checks.NoError(t, err, "RetrieveRunStep error") diff --git a/stream_reader.go b/stream_reader.go index 4210a1948..433548794 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -16,7 +16,7 @@ var ( ) type streamable interface { - ChatCompletionStreamResponse | CompletionResponse + ChatCompletionStreamResponse | CompletionResponse | AssistantStreamEvent } type streamReader[T streamable] struct { From 1870579b6933ea9966bed7a1952eb70ebc31c456 Mon Sep 17 00:00:00 2001 From: coolbaluk Date: Wed, 8 May 2024 10:01:35 +0100 Subject: [PATCH 2/4] Lint --- run.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/run.go b/run.go index 687470cb8..0db2ec1d2 100644 --- a/run.go +++ b/run.go @@ -469,7 +469,9 @@ type AssistantStream struct { *streamReader[AssistantStreamEvent] } -func (c *Client) CreateThreadAndStream(ctx context.Context, request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { +func (c *Client) CreateThreadAndStream( + ctx context.Context, + request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { urlSuffix := "/threads/runs" sr := CreateThreadAndStreamRequest{ CreateThreadAndRunRequest: request, @@ -496,7 +498,10 @@ func (c *Client) CreateThreadAndStream(ctx context.Context, request CreateThread return } -func (c *Client) CreateRunStreaming(ctx context.Context, threadID string, request RunRequest) (stream *AssistantStream, err error) { +func (c *Client) CreateRunStreaming( + ctx context.Context, + threadID string, + request RunRequest) (stream *AssistantStream, err error) { urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) r := RunRequestStreaming{ From 80aedac6610cb8600963a0b0900da85e73630ed0 Mon Sep 17 00:00:00 2001 From: coolbaluk Date: Mon, 13 May 2024 10:05:48 +0100 Subject: [PATCH 3/4] Add basic tests --- run_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/run_test.go b/run_test.go index f06c1564b..f3445852e 100644 --- a/run_test.go +++ b/run_test.go @@ -234,6 +234,16 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndStream error") + _, err = client.CreateRunStreaming(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRunStreaming error") + + _, err = client.SubmitToolOutputsStream(ctx, threadID, runID, openai.SubmitToolOutputsRequest{ + ToolOutputs: nil, + }) + checks.NoError(t, err, "SubmitToolOutputsStream error") + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) checks.NoError(t, err, "RetrieveRunStep error") From 72a39f449f960a1a724fbfe02189f53048576c68 Mon Sep 17 00:00:00 2001 From: coolbaluk Date: Wed, 15 May 2024 14:51:13 +0100 Subject: [PATCH 4/4] Add tool choice --- run.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/run.go b/run.go index 0db2ec1d2..8e64560ec 100644 --- a/run.go +++ b/run.go @@ -103,6 +103,9 @@ type RunRequest struct { // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // Used to force a tool call, only available in v2 API + ToolChoice any `json:"tool_choice,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread.